Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-03-23 10:13:45 +08:00
commit 6816f89189
233 changed files with 15445 additions and 1656 deletions

View File

@ -94,11 +94,6 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- name: Setup web environment
if: steps.web-changes.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web

View File

@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt in.

View File

@ -1,7 +1,7 @@
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -33,16 +33,10 @@ api_key_list_model = console_ns.model(
def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
@ -80,10 +74,13 @@ class BaseApiKeyListResource(Resource):
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
current_key_count: int = (
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
)
or 0
)
if current_key_count >= self.max_keys:
@ -119,14 +116,14 @@ class BaseApiKeyResource(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
key = (
db.session.query(ApiToken)
key = db.session.scalar(
select(ApiToken)
.where(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
.limit(1)
)
if key is None:
@ -137,7 +134,7 @@ class BaseApiKeyResource(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return {"result": "success"}, 204

View File

@ -5,7 +5,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -376,8 +376,12 @@ class CompletionConversationApi(Resource):
# FIXME, the type ignore in this file
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
query = (
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
)
elif args.annotation_status == "not_annotated":
query = (
@ -511,8 +515,12 @@ class ChatConversationApi(Resource):
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
query = (
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
)
case "not_annotated":
query = (

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
@ -17,14 +18,18 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
banners = db.session.scalars(
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
banners = db.session.scalars(
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
).all()
# Convert banners to serializable format
result = []
for banner in banners:

View File

@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
def post(self):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
recommended_app = db.session.scalar(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
)
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == payload.app_id).first()
app = db.session.get(App, payload.app_id)
if app is None:
raise NotFound("App entity not found")
@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = (
db.session.query(InstalledApp)
installed_app = db.session.scalar(
select(InstalledApp)
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first()
.limit(1)
)
if installed_app is None:

View File

@ -4,6 +4,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -476,7 +477,7 @@ class TrialSitApi(Resource):
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise Forbidden()
@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
workflow = db.session.get(Workflow, app_model.workflow_id)
return workflow

View File

@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
installed_app = (
db.session.query(InstalledApp)
installed_app = db.session.scalar(
select(InstalledApp)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if installed_app is None:
@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
if trial_app is None:
raise TrialAppNotAllowed()
@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
account_trial_app_record = db.session.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
.limit(1)
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:

View File

@ -2,6 +2,7 @@ from typing import Literal
from flask import request
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from configs import dify_config
from controllers.fastopenapi import console_router
@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
def get_setup_status() -> DifySetup | bool | None:
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()
return db.session.scalar(select(DifySetup).limit(1))
return True

View File

@ -212,13 +212,13 @@ class AccountInitApi(Resource):
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
invitation_code = db.session.scalar(
select(InvitationCode)
.where(
InvitationCode.code == args.invitation_code,
InvitationCode.status == InvitationCodeStatus.UNUSED,
)
.first()
.limit(1)
)
if not invitation_code:

View File

@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
member = db.session.get(Account, str(member_id))
if member is None:
abort(404)
else:

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
@ -29,6 +30,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
@ -108,9 +110,29 @@ class TenantListApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
tenant_plans: dict[str, SubscriptionPlan] = {}
if is_saas:
tenant_ids = [tenant.id for tenant in tenants]
if tenant_ids:
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
if not tenant_plans:
logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path")
for tenant in tenants:
features = FeatureService.get_features(tenant.id)
plan: str = CloudPlan.SANDBOX
if is_saas:
tenant_plan = tenant_plans.get(tenant.id)
if tenant_plan:
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
else:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
elif not is_enterprise_only:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
# Create a dictionary with tenant attributes
tenant_dict = {
@ -118,7 +140,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"plan": plan,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")

View File

@ -7,6 +7,7 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from sqlalchemy import select
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
if os.environ.get("INIT_PASSWORD"):
raise NotInitValidateError()
raise NotSetupError()
return view(*args, **kwargs)

View File

@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = "api"
from_source = ConversationFromSource.API
end_user_id = application_generate_entity.user_id
else:
from_source = "console"
from_source = ConversationFromSource.CONSOLE
account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):

View File

@ -6,7 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.enums import CollectionBindingType, ConversationFromSource
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@ -68,9 +68,9 @@ class AnnotationReplyFeature:
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
from_source = "api"
from_source = ConversationFromSource.API
else:
from_source = "console"
from_source = ConversationFromSource.CONSOLE
# insert annotation history
AppAnnotationService.add_annotation_history(

View File

@ -284,27 +284,29 @@ class TidbOnQdrantVector(BaseVector):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
if not ids:
return
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []

View File

@ -34,10 +34,12 @@ from .enums import (
AppMCPServerStatus,
AppStatus,
BannerStatus,
ConversationFromSource,
ConversationStatus,
CreatorUserRole,
FeedbackFromSource,
FeedbackRating,
InvokeFrom,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
@ -1022,10 +1024,12 @@ class Conversation(Base):
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = mapped_column(String(255), nullable=True)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
# ref: ConversationFromSource.
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
read_at = mapped_column(sa.DateTime)
@ -1374,8 +1378,10 @@ class Message(Base):
)
error: Mapped[str | None] = mapped_column(LongText)
message_metadata: Mapped[str | None] = mapped_column(LongText)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())

View File

@ -335,7 +335,11 @@ class BillingService:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
plan_dict = json.loads(json_str)
# NOTE (hj24): New billing versions may return timestamp as str, and validate_python
# in non-strict mode will coerce it to the expected int type.
# To preserve compatibility, always keep non-strict mode here and avoid strict mode.
subscription_plan = subscription_adapter.validate_python(plan_dict)
# NOTE END
tenant_plans[tenant_id] = subscription_plan
except Exception:
logger.exception(

View File

@ -1,10 +1,10 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_runtime import ToolRuntime
@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ApiSchemaParseResult(TypedDict):
    """Typed result of parsing an API schema into a tool bundle description.

    Returned by ``ApiToolManageService.parser_api_schema`` (the dict is built
    there and passed through ``jsonable_encoder``).
    """

    # Detected schema type reported by the parser (e.g. OpenAPI variant).
    schema_type: str
    # JSON-serializable parameter definitions extracted from the schema.
    parameters_schema: list[dict[str, Any]]
    # JSON-serializable credential field definitions for the provider.
    credentials_schema: list[dict[str, Any]]
    # Non-fatal issues encountered while parsing, keyed by warning name.
    warning: dict[str, str]
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> Mapping[str, Any]:
def parser_api_schema(schema: str) -> ApiSchemaParseResult:
"""
parse api schema to tool bundle
"""
@ -71,7 +78,7 @@ class ApiToolManageService:
]
return cast(
Mapping,
ApiSchemaParseResult,
jsonable_encoder(
{
"schema_type": schema_type,

View File

@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.types import Tool as MCPTool
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.utils.encryption import ProviderConfigEncrypter
from models.tools import MCPToolProvider
@ -681,7 +682,7 @@ class MCPToolManageService:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool]
) -> ToolProviderApiEntity:
"""Build API response for tool provider."""
user = db_provider.load_user()
@ -703,7 +704,7 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
raise error
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format."""

View File

@ -1,5 +1,7 @@
import json
from typing import Any, TypedDict
from typing import Any
from typing_extensions import TypedDict
from core.app.app_config.entities import (
DatasetEntity,
@ -34,6 +36,17 @@ class _NodeType(TypedDict):
data: dict[str, Any]
class _EdgeType(TypedDict):
    """Shape of a single edge entry in a converted workflow graph."""

    # Edge identifier; built as "{source}-{target}" by _create_edge.
    id: str
    # Id of the source node.
    source: str
    # Id of the target node.
    target: str
class WorkflowGraph(TypedDict):
    """Top-level workflow graph produced by WorkflowConverter: nodes plus edges."""

    # All node entries of the graph.
    nodes: list[_NodeType]
    # All edge entries connecting the nodes.
    edges: list[_EdgeType]
class WorkflowConverter:
"""
App Convert to Workflow Mode
@ -107,7 +120,7 @@ class WorkflowConverter:
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
# init workflow graph
graph: dict[str, Any] = {"nodes": [], "edges": []}
graph: WorkflowGraph = {"nodes": [], "edges": []}
# Convert list:
# - variables -> start
@ -385,7 +398,7 @@ class WorkflowConverter:
self,
original_app_mode: AppMode,
new_app_mode: AppMode,
graph: dict,
graph: WorkflowGraph,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: FileUploadConfig | None = None,
@ -595,7 +608,7 @@ class WorkflowConverter:
"data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str):
def _create_edge(self, source: str, target: str) -> _EdgeType:
"""
Create Edge
:param source: source node id
@ -604,7 +617,7 @@ class WorkflowConverter:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict[str, Any], node: _NodeType):
def _append_node(self, graph: WorkflowGraph, node: _NodeType):
"""
Append Node to Graph

View File

@ -5,6 +5,7 @@ from typing import Any
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from dify_graph.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
@ -14,6 +15,10 @@ from services.plugin.plugin_service import PluginService
from services.workflow.entities import TriggerMetadata
class LogViewDetails(TypedDict):
    """Computed extra details attached to a WorkflowAppLog row via LogView."""

    # Trigger metadata for the run; None when no trigger information exists.
    trigger_metadata: dict[str, Any] | None


# Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it
class LogView:
"""Lightweight wrapper for WorkflowAppLog with computed details.
@ -22,12 +27,12 @@ class LogView:
- Proxies all other attributes to the underlying `WorkflowAppLog`
"""
def __init__(self, log: WorkflowAppLog, details: dict | None):
def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
self.log = log
self.details_ = details
@property
def details(self) -> dict | None:
def details(self) -> LogViewDetails | None:
return self.details_
def __getattr__(self, name):

View File

@ -35,7 +35,7 @@ from factories.variable_factory import build_segment, segment_to_variable
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models import Account, App, Conversation
from models.enums import DraftVariableType
from models.enums import ConversationFromSource, DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory
from services.file_service import FileService
@ -601,7 +601,7 @@ class WorkflowDraftVariableService:
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.DEBUGGER,
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account_id,
)

View File

@ -13,6 +13,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import ConversationFromSource
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -154,7 +155,7 @@ class TestChatMessageApiPermissions:
re_sign_file_url_answer="",
answer_tokens=0,
provider_response_latency=0.0,
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=mock_account.id,
feedbacks=[],

View File

@ -165,8 +165,9 @@ class DifyTestContainers:
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
# Use pinned version 0.2.12 to match production docker-compose configuration
logger.info("Initializing Dify Sandbox container...")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network)
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network)
self.dify_sandbox.with_exposed_ports(8194)
self.dify_sandbox.env = {
"API_KEY": "test_api_key",

View File

@ -13,7 +13,7 @@ from libs.datetime_utils import naive_utc_now
from libs.token import _real_cookie_name, generate_csrf_token
from models import Account, DifySetup, Tenant, TenantAccountJoin
from models.account import AccountStatus, TenantAccountRole
from models.enums import CreatorUserRole
from models.enums import ConversationFromSource, CreatorUserRole
from models.model import App, AppMode, Conversation, Message
from models.workflow import WorkflowRun
from services.account_service import AccountService
@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C
inputs={},
status="normal",
mode=AppMode.CHAT,
from_source=CreatorUserRole.ACCOUNT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
)
db_session.add(conversation)
@ -124,7 +124,7 @@ def _create_message(
answer_price_unit=0.001,
currency="USD",
status="normal",
from_source=CreatorUserRole.ACCOUNT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
workflow_run_id=workflow_run_id,
inputs={"query": "Hello"},

View File

@ -7,6 +7,7 @@ from uuid import uuid4
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
from models.account import Account, Tenant, TenantAccountJoin
from models.enums import ConversationFromSource, InvokeFrom
from models.execution_extra_content import HumanInputContent
from models.human_input import HumanInputForm, HumanInputFormStatus
from models.model import App, Conversation, Message
@ -78,8 +79,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
introduction="",
system_instruction="",
status="normal",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
from_end_user_id=None,
)
@ -101,7 +102,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
answer_unit_price=Decimal("0.001"),
provider_response_latency=0.5,
currency="USD",
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
workflow_run_id=workflow_run_id,
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from dify_graph.entities import WorkflowExecution
from dify_graph.entities.pause_reason import PauseReasonType
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import (
BackstageRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
WorkflowRun.app_id == scope.app_id,
)
)
form_ids_subquery = select(HumanInputForm.id).where(
HumanInputForm.tenant_id == scope.tenant_id,
HumanInputForm.app_id == scope.app_id,
)
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
session.execute(
delete(HumanInputForm).where(
HumanInputForm.tenant_id == scope.tenant_id,
HumanInputForm.app_id == scope.app_id,
)
)
session.commit()
for state_key in scope.state_keys:
@ -504,3 +529,200 @@ class TestDeleteWorkflowPause:
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
repository.delete_workflow_pause(pause_entity=pause_entity)
class TestPrivateWorkflowPauseEntity:
    """Integration tests covering _PrivateWorkflowPauseEntity against real DB rows."""

    def _persist_pause(self, session: Session, scope: _TestScope, state_key):
        """Create a RUNNING workflow run plus a committed WorkflowPause bound to *state_key*.

        Registers the key on the scope so the fixture's cleanup removes it.
        Returns (workflow_run, pause_row) after a refresh round-trip.
        """
        run = _create_workflow_run(
            session,
            scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        pause_row = WorkflowPause(
            id=str(uuid4()),
            workflow_id=scope.workflow_id,
            workflow_run_id=run.id,
            state_object_key=state_key,
        )
        session.add(pause_row)
        session.commit()
        session.refresh(pause_row)
        scope.state_keys.add(state_key)
        return run, pause_row

    def test_properties(
        self,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Entity properties delegate to the persisted WorkflowPause model."""
        run, pause_row = self._persist_pause(
            db_session_with_containers, test_scope, f"workflow-state-{uuid4()}.json"
        )
        wrapped = _PrivateWorkflowPauseEntity(pause_model=pause_row, reason_models=[], human_input_form=[])
        assert wrapped.id == pause_row.id
        assert wrapped.workflow_execution_id == run.id
        assert wrapped.resumed_at is None

    def test_get_state(
        self,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """get_state loads state data from storage using the persisted state_object_key."""
        key = f"workflow-state-{uuid4()}.json"
        _, pause_row = self._persist_pause(db_session_with_containers, test_scope, key)
        payload = b'{"test": "state"}'
        storage.save(key, payload)
        wrapped = _PrivateWorkflowPauseEntity(pause_model=pause_row, reason_models=[], human_input_form=[])
        assert wrapped.get_state() == payload

    def test_get_state_caching(
        self,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """get_state caches the result so storage is only accessed once."""
        key = f"workflow-state-{uuid4()}.json"
        _, pause_row = self._persist_pause(db_session_with_containers, test_scope, key)
        payload = b'{"test": "state"}'
        storage.save(key, payload)
        wrapped = _PrivateWorkflowPauseEntity(pause_model=pause_row, reason_models=[], human_input_form=[])
        first = wrapped.get_state()
        # Delete the stored object so a cache miss on the second call would fail loudly.
        storage.delete(key)
        test_scope.state_keys.discard(key)
        second = wrapped.get_state()
        assert first == payload
        assert second == payload
class TestBuildHumanInputRequiredReason:
    """Integration tests for _build_human_input_required_reason using real DB models."""

    def test_prefers_backstage_token_when_available(
        self,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Use backstage token when multiple recipient types may exist."""
        expiration_time = naive_utc_now()
        # Full form definition: one text input, one user action, plus render metadata.
        form_definition = FormDefinition(
            form_content="content",
            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
            user_actions=[UserAction(id="approve", title="Approve")],
            rendered_content="rendered",
            expiration_time=expiration_time,
            default_values={"name": "Alice"},
            node_title="Ask Name",
            display_in_ui=True,
        )
        # The form row stores the definition as serialized JSON, not the model object.
        form_model = HumanInputForm(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            workflow_run_id=str(uuid4()),
            node_id="node-1",
            form_definition=form_definition.model_dump_json(),
            rendered_content="rendered",
            status=HumanInputFormStatus.WAITING,
            expiration_time=expiration_time,
        )
        db_session_with_containers.add(form_model)
        # flush (not commit) so form_model.id is assigned for the FK references below.
        db_session_with_containers.flush()
        delivery = HumanInputDelivery(
            form_id=form_model.id,
            delivery_method_type=DeliveryMethodType.WEBAPP,
            channel_payload="{}",
        )
        db_session_with_containers.add(delivery)
        db_session_with_containers.flush()
        access_token = secrets.token_urlsafe(8)
        # BACKSTAGE recipient carrying the token the builder is expected to pick.
        recipient = HumanInputFormRecipient(
            form_id=form_model.id,
            delivery_id=delivery.id,
            recipient_type=RecipientType.BACKSTAGE,
            recipient_payload=BackstageRecipientPayload().model_dump_json(),
            access_token=access_token,
        )
        db_session_with_containers.add(recipient)
        db_session_with_containers.flush()
        # Create a pause so the reason has a valid pause_id
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        pause = WorkflowPause(
            id=str(uuid4()),
            workflow_id=test_scope.workflow_id,
            workflow_run_id=workflow_run.id,
            state_object_key=f"workflow-state-{uuid4()}.json",
        )
        db_session_with_containers.add(pause)
        db_session_with_containers.flush()
        # Register the key so the test_scope fixture cleans up the stored state.
        test_scope.state_keys.add(pause.state_object_key)
        reason_model = WorkflowPauseReason(
            pause_id=pause.id,
            type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
            form_id=form_model.id,
            node_id="node-1",
            message="",
        )
        db_session_with_containers.add(reason_model)
        db_session_with_containers.commit()
        # Refresh to ensure we have DB-round-tripped objects
        db_session_with_containers.refresh(form_model)
        db_session_with_containers.refresh(reason_model)
        db_session_with_containers.refresh(recipient)
        reason = _build_human_input_required_reason(reason_model, form_model, [recipient])
        # The builder should surface the backstage token and the form definition fields.
        assert isinstance(reason, HumanInputRequired)
        assert reason.form_token == access_token
        assert reason.node_title == "Ask Name"
        assert reason.form_content == "content"
        assert reason.inputs[0].output_variable_name == "name"
        assert reason.actions[0].id == "approve"

View File

@ -0,0 +1,391 @@
"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers."""
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field
from datetime import timedelta
from uuid import uuid4

import pytest
from sqlalchemy import Engine, delete
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import Session, sessionmaker

from dify_graph.entities import WorkflowExecution
from dify_graph.enums import WorkflowExecutionStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun, WorkflowType
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
    """Repository subclass for tests; persistence via save() is not under test."""

    def save(self, execution: WorkflowExecution) -> None:
        """Intentionally a no-op — these tests only exercise read paths."""
@dataclass
class _TestScope:
    """Per-test data scope used to isolate DB rows."""

    # Each field defaults to a fresh random UUID string, so every _TestScope()
    # instance identifies a distinct tenant/app/workflow/user and tests cannot
    # collide on shared rows.
    tenant_id: str = field(default_factory=lambda: str(uuid4()))
    app_id: str = field(default_factory=lambda: str(uuid4()))
    workflow_id: str = field(default_factory=lambda: str(uuid4()))
    user_id: str = field(default_factory=lambda: str(uuid4()))
def _create_workflow_run(
    session: Session,
    scope: _TestScope,
    *,
    status: WorkflowExecutionStatus,
    triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
    created_at_offset: timedelta | None = None,
) -> WorkflowRun:
    """Persist a WorkflowRun bound to *scope* and return it.

    *created_at_offset*, when given, shifts the creation timestamp so tests
    can force a deterministic ordering of runs.
    """
    base_time = naive_utc_now()
    created_at = base_time if created_at_offset is None else base_time + created_at_offset
    run = WorkflowRun(
        id=str(uuid4()),
        tenant_id=scope.tenant_id,
        app_id=scope.app_id,
        workflow_id=scope.workflow_id,
        type=WorkflowType.WORKFLOW,
        triggered_from=triggered_from,
        version="draft",
        graph="{}",
        inputs="{}",
        status=status,
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=scope.user_id,
        created_at=created_at,
    )
    session.add(run)
    session.commit()
    return run
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
    """Delete every WorkflowRun row belonging to *scope* and commit."""
    stmt = delete(WorkflowRun).where(
        WorkflowRun.tenant_id == scope.tenant_id,
        WorkflowRun.app_id == scope.app_id,
    )
    session.execute(stmt)
    session.commit()
@pytest.fixture
def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
    """Build a repository backed by the testcontainers database engine."""
    bind = db_session_with_containers.get_bind()
    assert isinstance(bind, Engine)
    # expire_on_commit=False keeps returned ORM objects readable after commit.
    factory = sessionmaker(bind=bind, expire_on_commit=False)
    return _TestWorkflowRunRepository(session_maker=factory)
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> Iterator[_TestScope]:
    """Provide an isolated scope and clean related data after each test.

    Annotated as Iterator[_TestScope] (not _TestScope) because this is a
    generator-style fixture: pytest injects the yielded value into tests,
    and the code after ``yield`` runs as teardown.
    """
    scope = _TestScope()
    yield scope
    _cleanup_scope_data(db_session_with_containers, scope)
class TestGetPaginatedWorkflowRuns:
    """Integration tests for get_paginated_workflow_runs."""

    def test_returns_runs_without_status_filter(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Return all runs for the given tenant/app when no status filter is applied."""
        # One run per status; all three must come back when status=None.
        for status in (
            WorkflowExecutionStatus.SUCCEEDED,
            WorkflowExecutionStatus.FAILED,
            WorkflowExecutionStatus.RUNNING,
        ):
            _create_workflow_run(db_session_with_containers, test_scope, status=status)
        result = repository.get_paginated_workflow_runs(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=20,
            last_id=None,
            status=None,
        )
        assert len(result.data) == 3
        assert result.limit == 20
        assert result.has_more is False

    def test_filters_by_status(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Return only runs matching the requested status."""
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
        # status is passed as the lowercase string form, matching the API surface.
        result = repository.get_paginated_workflow_runs(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=20,
            last_id=None,
            status="succeeded",
        )
        assert len(result.data) == 2
        assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data)

    def test_pagination_has_more(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Return has_more=True when more records exist beyond the limit."""
        # Distinct created_at offsets give the runs a deterministic ordering.
        for i in range(5):
            _create_workflow_run(
                db_session_with_containers,
                test_scope,
                status=WorkflowExecutionStatus.SUCCEEDED,
                created_at_offset=timedelta(seconds=i),
            )
        result = repository.get_paginated_workflow_runs(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=3,
            last_id=None,
            status=None,
        )
        assert len(result.data) == 3
        assert result.has_more is True

    def test_cursor_based_pagination(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Cursor-based pagination returns the next page of results."""
        for i in range(5):
            _create_workflow_run(
                db_session_with_containers,
                test_scope,
                status=WorkflowExecutionStatus.SUCCEEDED,
                created_at_offset=timedelta(seconds=i),
            )
        # First page
        page1 = repository.get_paginated_workflow_runs(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=3,
            last_id=None,
            status=None,
        )
        assert len(page1.data) == 3
        assert page1.has_more is True
        # Second page using cursor
        page2 = repository.get_paginated_workflow_runs(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=3,
            last_id=page1.data[-1].id,
            status=None,
        )
        assert len(page2.data) == 2
        assert page2.has_more is False
        # No overlap between pages
        page1_ids = {r.id for r in page1.data}
        page2_ids = {r.id for r in page2.data}
        assert page1_ids.isdisjoint(page2_ids)

    def test_invalid_last_id_raises(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        test_scope: _TestScope,
    ) -> None:
        """Raise ValueError when last_id refers to a non-existent run."""
        with pytest.raises(ValueError, match="Last workflow run not exists"):
            repository.get_paginated_workflow_runs(
                tenant_id=test_scope.tenant_id,
                app_id=test_scope.app_id,
                triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
                limit=20,
                last_id=str(uuid4()),
                status=None,
            )

    def test_tenant_isolation(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Runs from other tenants are not returned."""
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        # Second scope shares the app_id but has a different tenant_id, so a
        # match on app_id alone would leak its run into the result.
        other_scope = _TestScope(app_id=test_scope.app_id)
        try:
            _create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED)
            result = repository.get_paginated_workflow_runs(
                tenant_id=test_scope.tenant_id,
                app_id=test_scope.app_id,
                triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
                limit=20,
                last_id=None,
                status=None,
            )
            assert len(result.data) == 1
            assert result.data[0].tenant_id == test_scope.tenant_id
        finally:
            # other_scope is not covered by the test_scope fixture teardown.
            _cleanup_scope_data(db_session_with_containers, other_scope)
class TestGetWorkflowRunsCount:
    """Integration tests for get_workflow_runs_count."""

    def test_count_without_status_filter(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Count all runs grouped by status when no status filter is applied."""
        for _ in range(3):
            _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        for _ in range(2):
            _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING)
        result = repository.get_workflow_runs_count(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status=None,
        )
        # Result is keyed by lowercase status strings; absent statuses report 0.
        assert result["total"] == 6
        assert result["succeeded"] == 3
        assert result["failed"] == 2
        assert result["running"] == 1
        assert result["stopped"] == 0
        assert result["partial-succeeded"] == 0

    def test_count_with_status_filter(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Count only runs matching the requested status."""
        for _ in range(3):
            _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
        result = repository.get_workflow_runs_count(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="succeeded",
        )
        # total reflects only the filtered status, so the FAILED run is excluded.
        assert result["total"] == 3
        assert result["succeeded"] == 3
        assert result["failed"] == 0

    def test_count_with_invalid_status_raises(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Invalid status raises StatementError because the column uses an enum type."""
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        with pytest.raises(sa_exc.StatementError) as exc_info:
            repository.get_workflow_runs_count(
                tenant_id=test_scope.tenant_id,
                app_id=test_scope.app_id,
                triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
                status="invalid_status",
            )
        # SQLAlchemy wraps the enum-coercion ValueError in a StatementError.
        assert isinstance(exc_info.value.orig, ValueError)

    def test_count_with_time_range(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Time range filter excludes runs created outside the window."""
        # Recent run (within 1 day)
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        # Old run (8 days ago)
        _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.SUCCEEDED,
            created_at_offset=timedelta(days=-8),
        )
        # "7d" limits the count to the last 7 days, so the 8-day-old run drops out.
        result = repository.get_workflow_runs_count(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status=None,
            time_range="7d",
        )
        assert result["total"] == 1
        assert result["succeeded"] == 1

    def test_count_with_status_and_time_range(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Both status and time_range filters apply together."""
        # Recent succeeded
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
        # Recent failed
        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
        # Old succeeded (outside time range)
        _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.SUCCEEDED,
            created_at_offset=timedelta(days=-8),
        )
        result = repository.get_workflow_runs_count(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="succeeded",
            time_range="7d",
        )
        # Only the run that is both recent AND succeeded survives both filters.
        assert result["total"] == 1
        assert result["succeeded"] == 1
        assert result["failed"] == 0

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.enums import MessageFileBelongsTo
from models.enums import ConversationFromSource, MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@ -165,7 +165,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
@ -204,7 +204,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()
@ -406,7 +406,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
@ -445,7 +445,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()
@ -478,7 +478,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
@ -517,7 +517,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()
@ -624,7 +624,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
app_model_config_id=None, # Explicitly set to None
)
db_session_with_containers.add(conversation)
@ -647,7 +647,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from models import Account
from models.enums import ConversationFromSource, InvokeFrom
from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.app_service import AppService
@ -136,8 +137,8 @@ class TestAnnotationService:
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)
@ -174,8 +175,8 @@ class TestAnnotationService:
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)
@ -721,7 +722,7 @@ class TestAnnotationService:
query=f"Query {i}: {fake.sentence()}",
user_id=account.id,
message_id=fake.uuid4(),
from_source="console",
from_source=ConversationFromSource.CONSOLE,
score=0.8 + (i * 0.1),
)
@ -772,7 +773,7 @@ class TestAnnotationService:
query=query,
user_id=account.id,
message_id=message_id,
from_source="console",
from_source=ConversationFromSource.CONSOLE,
score=score,
)

View File

@ -10,6 +10,7 @@ from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
from models.enums import ConversationFromSource
from models.model import App, Conversation, EndUser, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.conversation_service import ConversationService
@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory:
system_instruction_tokens=0,
status="normal",
invoke_from=invoke_from.value,
from_source="api" if isinstance(user, EndUser) else "console",
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
dialogue_count=0,
@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory:
currency="USD",
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
from_source="api" if isinstance(user, EndUser) else "console",
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
)

View File

@ -7,7 +7,7 @@ import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
from models.model import (
App,
AppAnnotationHitHistory,
@ -94,7 +94,7 @@ class TestAppMessageExportServiceIntegration:
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
@ -129,7 +129,7 @@ class TestAppMessageExportServiceIntegration:
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)

View File

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import FeedbackRating
from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
@ -149,8 +149,8 @@ class TestMessageService:
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)
@ -187,8 +187,8 @@ class TestMessageService:
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)

View File

@ -4,6 +4,7 @@ from decimal import Decimal
import pytest
from models.enums import ConversationFromSource
from models.model import Message
from services import message_service
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit
total_price=Decimal(0),
currency="USD",
status="normal",
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_account_id=fixture.account.id,
)
db_session_with_containers.add(message_without_extra_content)

View File

@ -11,7 +11,14 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
from models.enums import (
ConversationFromSource,
DataSourceType,
FeedbackFromSource,
FeedbackRating,
MessageChainType,
MessageFileBelongsTo,
)
from models.model import (
App,
AppAnnotationHitHistory,
@ -166,7 +173,7 @@ class TestMessagesCleanServiceIntegration:
name="Test conversation",
inputs={},
status="normal",
from_source=FeedbackFromSource.USER,
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(conversation)
@ -196,7 +203,7 @@ class TestMessagesCleanServiceIntegration:
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source=FeedbackFromSource.USER,
from_source=ConversationFromSource.API,
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService
@ -132,11 +133,14 @@ class TestSavedMessageService:
# Create a simple conversation first
from models.model import Conversation
is_account = hasattr(user, "current_tenant")
from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API
conversation = Conversation(
app_id=app.id,
from_source="account" if hasattr(user, "current_tenant") else "end_user",
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
from_account_id=user.id if hasattr(user, "current_tenant") else None,
from_source=from_source,
from_end_user_id=user.id if not is_account else None,
from_account_id=user.id if is_account else None,
name=fake.sentence(nb_words=3),
inputs={},
status="normal",
@ -150,9 +154,9 @@ class TestSavedMessageService:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
from_source="account" if hasattr(user, "current_tenant") else "end_user",
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
from_account_id=user.id if hasattr(user, "current_tenant") else None,
from_source=from_source,
from_end_user_id=user.id if not is_account else None,
from_account_id=user.id if is_account else None,
inputs={},
query=fake.sentence(nb_words=5),
message=fake.text(max_nb_chars=100),

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
from services.account_service import AccountService, TenantService
@ -145,7 +146,7 @@ class TestWebConversationService:
system_instruction_tokens=50,
status="normal",
invoke_from=InvokeFrom.WEB_APP,
from_source="console" if isinstance(user, Account) else "api",
from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
dialogue_count=0,

View File

@ -7,7 +7,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import CreatorUserRole
from models.enums import ConversationFromSource, CreatorUserRole
from models.model import (
Message,
)
@ -165,7 +165,7 @@ class TestWorkflowRunService:
inputs={},
status="normal",
mode="chat",
from_source=CreatorUserRole.ACCOUNT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
)
db_session_with_containers.add(conversation)
@ -186,7 +186,7 @@ class TestWorkflowRunService:
message.answer_price_unit = 0.001
message.currency = "USD"
message.status = "normal"
message.from_source = CreatorUserRole.ACCOUNT
message.from_source = ConversationFromSource.CONSOLE
message.from_account_id = account.id
message.workflow_run_id = workflow_run.id
message.inputs = {"input": "test input"}

View File

@ -0,0 +1,320 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.local import LocalProxy
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.message import (
ChatMessageListApi,
ChatMessagesQuery,
FeedbackExportQuery,
MessageAnnotationCountApi,
MessageApi,
MessageFeedbackApi,
MessageFeedbackExportApi,
MessageFeedbackPayload,
MessageSuggestedQuestionApi,
)
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from models import App, AppMode
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@pytest.fixture
def app():
    """Flask app in testing mode with the restx field-mask header enabled."""
    application = Flask(__name__)
    application.config.update(TESTING=True, RESTX_MASK_HEADER="X-Fields")
    return application
@pytest.fixture
def mock_account():
    """Mock of an active, admin/owner Account with edit permission."""
    from models.account import Account, AccountStatus

    fake_account = MagicMock(spec=Account)
    fake_account.id = "user_123"
    fake_account.timezone = "UTC"
    fake_account.status = AccountStatus.ACTIVE
    fake_account.is_admin_or_owner = True
    # MagicMock auto-creates current_tenant, so the nested attribute just works.
    fake_account.current_tenant.current_role = "owner"
    fake_account.has_edit_permission = True
    return fake_account
@pytest.fixture
def mock_app_model():
    """Mock of a chat-mode App bound to a fixed tenant."""
    fake_app = MagicMock(spec=App)
    fake_app.id = "app_123"
    fake_app.mode = AppMode.CHAT
    fake_app.tenant_id = "tenant_123"
    return fake_app
@pytest.fixture(autouse=True)
def mock_csrf():
    """Patch out CSRF token checking for every test in this module (autouse)."""
    with patch("libs.login.check_csrf_token") as mock:
        yield mock
import contextlib
@contextlib.contextmanager
def setup_test_context(
    test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
):
    """Yield (api_instance, mock_db, view_args) with auth, db, and request context mocked.

    Patches the db object and current_account_with_tenant across the wraps and
    message modules, builds a request context for *route_path* (with optional
    query string *qs* and JSON *payload*), and instantiates *endpoint_class*.
    """
    with (
        patch("extensions.ext_database.db") as mock_db,
        patch("controllers.console.app.wraps.db", mock_db),
        patch("controllers.console.wraps.db", mock_db),
        patch("controllers.console.app.message.db", mock_db),
        patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
    ):
        # Set up a generic query mock that usually returns mock_app_model when getting app
        app_query_mock = MagicMock()
        # Cover both legacy .filter(...) and 2.0-style .where(...) chains.
        app_query_mock.filter.return_value.first.return_value = mock_app_model
        app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
        app_query_mock.where.return_value.first.return_value = mock_app_model
        app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
        data_query_mock = MagicMock()

        def query_side_effect(*args, **kwargs):
            # Dispatch: queries on the App model get the app mock; everything
            # else falls through to the caller-configurable data query mock.
            if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
                return app_query_mock
            return data_query_mock

        mock_db.session.query.side_effect = query_side_effect
        mock_db.data_query = data_query_mock
        # Let the caller override the stat db query logic
        proxy_mock = LocalProxy(lambda: mock_account)
        query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
        full_path = f"{route_path}?{query_string}" if qs else route_path
        with (
            patch("libs.login.current_user", proxy_mock),
            patch("flask_login.current_user", proxy_mock),
            patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
        ):
            with test_app.test_request_context(full_path, method=method, json=payload):
                request.view_args = {"app_id": "app_123"}
                if "suggested-questions" in route_path:
                    # simplistic extraction for message_id
                    parts = route_path.split("chat-messages/")
                    if len(parts) > 1:
                        request.view_args["message_id"] = parts[1].split("/")[0]
                elif "messages/" in route_path and "chat-messages" not in route_path:
                    parts = route_path.split("messages/")
                    if len(parts) > 1:
                        request.view_args["message_id"] = parts[1].split("/")[0]
                api_instance = endpoint_class()
                # Check if it has a dispatch_request or method
                # NOTE(review): if the class lacks the HTTP method, nothing is
                # yielded and the with-block body never runs — confirm intended.
                if hasattr(api_instance, method.lower()):
                    yield api_instance, mock_db, request.view_args
class TestMessageValidators:
    """Unit tests for the validator helpers on the message query/payload models."""

    def test_chat_messages_query_validators(self):
        """empty_to_none collapses empty strings; validate_uuid passes valid values through."""
        assert ChatMessagesQuery.empty_to_none("") is None
        assert ChatMessagesQuery.empty_to_none("val") == "val"
        assert ChatMessagesQuery.validate_uuid(None) is None
        uuid_str = "123e4567-e89b-12d3-a456-426614174000"
        assert ChatMessagesQuery.validate_uuid(uuid_str) == uuid_str

    def test_message_feedback_validators(self):
        """A valid UUID message_id is returned unchanged."""
        uuid_str = "123e4567-e89b-12d3-a456-426614174000"
        assert MessageFeedbackPayload.validate_message_id(uuid_str) == uuid_str

    def test_feedback_export_validators(self):
        """parse_bool accepts common truthy/falsy spellings and rejects anything else."""
        assert FeedbackExportQuery.parse_bool(None) is None
        assert FeedbackExportQuery.parse_bool(True) is True
        assert FeedbackExportQuery.parse_bool("1") is True
        for falsy_value in ("0", "off"):
            assert FeedbackExportQuery.parse_bool(falsy_value) is False
        with pytest.raises(ValueError):
            FeedbackExportQuery.parse_bool("invalid")
class TestMessageEndpoints:
    """Endpoint-level tests for the console message APIs, driven via setup_test_context."""

    def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
        """Listing chat messages raises NotFound when the conversation lookup yields no row."""
        with setup_test_context(
            app,
            ChatMessageListApi,
            "/apps/app_123/chat-messages",
            "GET",
            mock_account,
            mock_app_model,
            qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
        ) as (api, mock_db, v_args):
            mock_db.data_query.where.return_value.first.return_value = None
            with pytest.raises(NotFound):
                api.get(**v_args)

    def test_chat_message_list_success(self, app, mock_account, mock_app_model):
        """A conversation with one message serializes into the paginated response shape."""
        with setup_test_context(
            app,
            ChatMessageListApi,
            "/apps/app_123/chat-messages",
            "GET",
            mock_account,
            mock_app_model,
            qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
        ) as (api, mock_db, v_args):
            mock_conv = MagicMock()
            mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
            mock_msg = MagicMock()
            mock_msg.id = "msg_123"
            mock_msg.feedbacks = []
            mock_msg.annotation = None
            mock_msg.annotation_hit_history = None
            mock_msg.agent_thoughts = []
            mock_msg.message_files = []
            mock_msg.extra_contents = []
            mock_msg.message = {}
            mock_msg.message_metadata_dict = {}
            # mock returns
            q_mock = mock_db.data_query
            q_mock.where.return_value.first.side_effect = [mock_conv]
            q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
            # scalar() drives the has_more check
            mock_db.session.scalar.return_value = False
            resp = api.get(**v_args)
            assert resp["limit"] == 1
            assert resp["has_more"] is False
            assert len(resp["data"]) == 1

    def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
        """Posting feedback for a missing message raises NotFound."""
        with setup_test_context(
            app,
            MessageFeedbackApi,
            "/apps/app_123/feedbacks",
            "POST",
            mock_account,
            mock_app_model,
            payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
        ) as (api, mock_db, v_args):
            mock_db.data_query.where.return_value.first.return_value = None
            with pytest.raises(NotFound):
                api.post(**v_args)

    def test_message_feedback_success(self, app, mock_account, mock_app_model):
        """A 'like' rating on an existing message returns a success result."""
        payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
        with setup_test_context(
            app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
        ) as (api, mock_db, v_args):
            mock_msg = MagicMock()
            mock_msg.admin_feedback = None
            mock_db.data_query.where.return_value.first.return_value = mock_msg
            resp = api.post(**v_args)
            assert resp == {"result": "success"}

    def test_message_annotation_count(self, app, mock_account, mock_app_model):
        """The annotation count endpoint echoes the db count."""
        with setup_test_context(
            app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
        ) as (api, mock_db, v_args):
            mock_db.data_query.where.return_value.count.return_value = 5
            resp = api.get(**v_args)
            assert resp == {"count": 5}

    @patch("controllers.console.app.message.MessageService")
    def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
        """Suggested questions from the service are returned under the data key."""
        mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
        with setup_test_context(
            app,
            MessageSuggestedQuestionApi,
            "/apps/app_123/chat-messages/msg_123/suggested-questions",
            "GET",
            mock_account,
            mock_app_model,
        ) as (api, mock_db, v_args):
            resp = api.get(**v_args)
            assert resp == {"data": ["q1", "q2"]}

    @pytest.mark.parametrize(
        ("exc", "expected_exc"),
        [
            (MessageNotExistsError, NotFound),
            (ConversationNotExistsError, NotFound),
            (ProviderTokenNotInitError, ProviderNotInitializeError),
            (QuotaExceededError, ProviderQuotaExceededError),
            (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
            (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
            (Exception, InternalServerError),
        ],
    )
    @patch("controllers.console.app.message.MessageService")
    def test_message_suggested_questions_errors(
        self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
    ):
        """Each service-layer error is translated to the matching HTTP-level exception."""
        mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
        with setup_test_context(
            app,
            MessageSuggestedQuestionApi,
            "/apps/app_123/chat-messages/msg_123/suggested-questions",
            "GET",
            mock_account,
            mock_app_model,
        ) as (api, mock_db, v_args):
            with pytest.raises(expected_exc):
                api.get(**v_args)

    @patch("services.feedback_service.FeedbackService.export_feedbacks")
    def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
        """The export endpoint passes through the service's return value."""
        mock_export.return_value = {"exported": True}
        with setup_test_context(
            app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
        ) as (api, mock_db, v_args):
            resp = api.get(**v_args)
            assert resp == {"exported": True}

    def test_message_api_get_success(self, app, mock_account, mock_app_model):
        """Fetching a single message serializes it with its id intact."""
        with setup_test_context(
            app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
        ) as (api, mock_db, v_args):
            mock_msg = MagicMock()
            mock_msg.id = "msg_123"
            mock_msg.feedbacks = []
            mock_msg.annotation = None
            mock_msg.annotation_hit_history = None
            mock_msg.agent_thoughts = []
            mock_msg.message_files = []
            mock_msg.extra_contents = []
            mock_msg.message = {}
            mock_msg.message_metadata_dict = {}
            mock_db.data_query.where.return_value.first.return_value = mock_msg
            resp = api.get(**v_args)
            assert resp["id"] == "msg_123"

View File

@ -0,0 +1,275 @@
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.statistic import (
AverageResponseTimeStatistic,
AverageSessionInteractionStatistic,
DailyConversationStatistic,
DailyMessageStatistic,
DailyTerminalsStatistic,
DailyTokenCostStatistic,
TokensPerSecondStatistic,
UserSatisfactionRateStatistic,
)
from models import App, AppMode
@pytest.fixture
def app():
    """Provide a minimal Flask application configured for testing."""
    application = Flask(__name__)
    application.config.update(TESTING=True)
    return application
@pytest.fixture
def mock_account():
    """Provide a MagicMock standing in for an active owner Account with edit rights."""
    from models.account import Account, AccountStatus

    fake_account = MagicMock(spec=Account)
    attrs = {
        "id": "user_123",
        "timezone": "UTC",
        "status": AccountStatus.ACTIVE,
        "is_admin_or_owner": True,
        "has_edit_permission": True,
    }
    for attr_name, attr_value in attrs.items():
        setattr(fake_account, attr_name, attr_value)
    fake_account.current_tenant.current_role = "owner"
    return fake_account
@pytest.fixture
def mock_app_model():
    """Provide a MagicMock App in CHAT mode belonging to tenant_123."""
    fake_app = MagicMock(spec=App)
    fake_app.id = "app_123"
    fake_app.tenant_id = "tenant_123"
    fake_app.mode = AppMode.CHAT
    return fake_app
@pytest.fixture(autouse=True)
def mock_csrf():
    """Disable CSRF token validation for every test in this module."""
    with patch("libs.login.check_csrf_token") as csrf_patch:
        yield csrf_patch
def setup_test_context(
    test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
):
    """Run a statistic endpoint GET under fully-patched db/auth and return its response.

    ``mock_rs`` is the iterable of rows that db.engine.begin().execute() will yield.
    NOTE(review): ``mock_parse_ret`` is never read in this body — callers patch
    ``parse_time_range`` directly instead; presumably a leftover parameter, confirm
    before removing (it is part of the helper's signature).
    """
    with (
        patch("controllers.console.app.statistic.db") as mock_db_stat,
        patch("controllers.console.app.wraps.db") as mock_db_wraps,
        patch("controllers.console.wraps.db", mock_db_wraps),
        patch(
            "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
        ),
        patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
    ):
        # Statistic endpoints read rows through db.engine.begin() -> conn.execute().
        mock_conn = MagicMock()
        mock_conn.execute.return_value = mock_rs
        mock_begin = MagicMock()
        mock_begin.__enter__.return_value = mock_conn
        mock_db_stat.engine.begin.return_value = mock_begin
        # The wraps decorators resolve the App through db.session.query(...).
        mock_query = MagicMock()
        mock_query.filter.return_value.first.return_value = mock_app_model
        mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
        mock_query.where.return_value.first.return_value = mock_app_model
        mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
        mock_db_wraps.session.query.return_value = mock_query
        proxy_mock = LocalProxy(lambda: mock_account)
        with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
            with test_app.test_request_context(route_path, method="GET"):
                request.view_args = {"app_id": "app_123"}
                api_instance = endpoint_class()
                response = api_instance.get(app_id="app_123")
                return response
class TestStatisticEndpoints:
    """GET-path tests for the console statistic endpoints using mocked db result rows."""

    def test_daily_message_statistic(self, app, mock_account, mock_app_model):
        """Daily message rows surface message_count in the JSON payload."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.message_count = 10
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyMessageStatistic,
                "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["message_count"] == 10

    def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
        """Daily conversation rows surface conversation_count."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.conversation_count = 5
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyConversationStatistic,
                "/apps/app_123/statistics/daily-conversations",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["conversation_count"] == 5

    def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
        """Daily end-user rows surface terminal_count."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.terminal_count = 2
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyTerminalsStatistic,
                "/apps/app_123/statistics/daily-end-users",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["terminal_count"] == 2

    def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
        """Token-cost rows surface token_count and total_price (stringified Decimal)."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.token_count = 100
        mock_row.total_price = Decimal("0.02")
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyTokenCostStatistic,
                "/apps/app_123/statistics/token-costs",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["token_count"] == 100
        assert response.json["data"][0]["total_price"] == "0.02"

    def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
        """Interaction averages are rounded to two decimal places (3.523 -> 3.52)."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.interactions = Decimal("3.523")
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                AverageSessionInteractionStatistic,
                "/apps/app_123/statistics/average-session-interactions",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["interactions"] == 3.52

    def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
        """10 feedbacks over 100 messages yields a rate of 100.0 (per-1000 scaling)."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.message_count = 100
        mock_row.feedback_count = 10
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                UserSatisfactionRateStatistic,
                "/apps/app_123/statistics/user-satisfaction-rate",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["rate"] == 100.0

    def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
        """Latency is converted seconds -> milliseconds (1.234 -> 1234.0); completion apps only."""
        mock_app_model.mode = AppMode.COMPLETION
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.latency = 1.234
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                AverageResponseTimeStatistic,
                "/apps/app_123/statistics/average-response-time",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["latency"] == 1234.0

    def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
        """Tokens-per-second rows surface under the tps key unchanged."""
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.tokens_per_second = 15.5
        mock_row.interactions = Decimal(0)
        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                TokensPerSecondStatistic,
                "/apps/app_123/statistics/tokens-per-second",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["tps"] == 15.5

    @patch("controllers.console.app.statistic.parse_time_range")
    def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
        """A ValueError from parse_time_range becomes a BadRequest."""
        mock_parse.side_effect = ValueError("Invalid time")
        from werkzeug.exceptions import BadRequest
        with pytest.raises(BadRequest):
            setup_test_context(
                app,
                DailyMessageStatistic,
                "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
                mock_account,
                mock_app_model,
                [],
            )

    @patch("controllers.console.app.statistic.parse_time_range")
    def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
        """parse_time_range is consulted exactly once when start/end are supplied."""
        import datetime
        start = datetime.datetime.now()
        end = datetime.datetime.now()
        mock_parse.return_value = (start, end)
        response = setup_test_context(
            app,
            DailyMessageStatistic,
            "/apps/app_123/statistics/daily-messages?start=something&end=something",
            mock_account,
            mock_app_model,
            [],
        )
        assert response.status_code == 200
        mock_parse.assert_called_once()

View File

@ -0,0 +1,313 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
ConversationVariableCollectionApi,
EnvironmentVariableCollectionApi,
NodeVariableCollectionApi,
SystemVariableCollectionApi,
VariableApi,
VariableResetApi,
WorkflowVariableCollectionApi,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from models import App, AppMode
from models.enums import DraftVariableType
@pytest.fixture
def app():
    """Provide a Flask application configured for testing (restx field masking enabled)."""
    application = Flask(__name__)
    application.config.update(TESTING=True, RESTX_MASK_HEADER="X-Fields")
    return application
@pytest.fixture
def mock_account():
    """Provide a MagicMock standing in for an active owner Account with edit rights."""
    from models.account import Account, AccountStatus

    fake_account = MagicMock(spec=Account)
    attrs = {
        "id": "user_123",
        "timezone": "UTC",
        "status": AccountStatus.ACTIVE,
        "is_admin_or_owner": True,
        "has_edit_permission": True,
    }
    for attr_name, attr_value in attrs.items():
        setattr(fake_account, attr_name, attr_value)
    fake_account.current_tenant.current_role = "owner"
    return fake_account
@pytest.fixture
def mock_app_model():
    """Provide a MagicMock App in WORKFLOW mode belonging to tenant_123."""
    fake_app = MagicMock(spec=App)
    fake_app.id = "app_123"
    fake_app.tenant_id = "tenant_123"
    fake_app.mode = AppMode.WORKFLOW
    return fake_app
@pytest.fixture(autouse=True)
def mock_csrf():
    """Disable CSRF token validation for every test in this module."""
    with patch("libs.login.check_csrf_token") as csrf_patch:
        yield csrf_patch
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
    """Invoke a draft-variable endpoint method under patched db/auth and return its result.

    Builds a request context for ``route_path`` (optional JSON ``payload``), extracts
    node_id / variable_id path parameters from the route string, instantiates
    ``endpoint_class`` and calls the handler matching ``method`` directly.
    """
    with (
        patch("controllers.console.app.wraps.db") as mock_db_wraps,
        patch("controllers.console.wraps.db", mock_db_wraps),
        patch("controllers.console.app.workflow_draft_variable.db"),
        patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
    ):
        # The wraps decorators resolve the App through db.session.query(...).
        mock_query = MagicMock()
        mock_query.filter.return_value.first.return_value = mock_app_model
        mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
        mock_query.where.return_value.first.return_value = mock_app_model
        mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
        mock_db_wraps.session.query.return_value = mock_query
        proxy_mock = LocalProxy(lambda: mock_account)
        with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
            with test_app.test_request_context(route_path, method=method, json=payload):
                request.view_args = {"app_id": "app_123"}
                # extract node_id or variable_id from path manually since view_args overrides
                if "nodes/" in route_path:
                    request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
                if "variables/" in route_path:
                    # simplistic extraction
                    parts = route_path.split("variables/")
                    if len(parts) > 1 and parts[1] and parts[1] != "reset":
                        request.view_args["variable_id"] = parts[1].split("/")[0]
                api_instance = endpoint_class()
                # we just call dispatch_request to avoid manual argument passing
                if hasattr(api_instance, method.lower()):
                    func = getattr(api_instance, method.lower())
                    return func(**request.view_args)
class TestWorkflowDraftVariableEndpoints:
    """Endpoint tests for the workflow draft-variable console APIs."""

    @staticmethod
    def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
        """Build a MagicMock draft variable (var_123 on app_123) shaped for serialization."""
        class DummyValueType:
            def exposed_type(self):
                return DraftVariableType.NODE
        mock_var = MagicMock()
        mock_var.app_id = "app_123"
        mock_var.id = "var_123"
        mock_var.name = "test_var"
        mock_var.description = ""
        mock_var.get_variable_type.return_value = variable_type
        mock_var.get_selector.return_value = []
        mock_var.value_type = DummyValueType()
        mock_var.edited = False
        mock_var.visible = True
        mock_var.file_id = None
        mock_var.variable_file = None
        mock_var.is_truncated.return_value = False
        mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
        return mock_var

    @patch("controllers.console.app.workflow_draft_variable.WorkflowService")
    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_workflow_variable_collection_get_success(
        self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
    ):
        """An existing draft workflow with no variables serializes to an empty page."""
        mock_wf_srv.return_value.is_workflow_exist.return_value = True
        from services.workflow_draft_variable_service import WorkflowDraftVariableList
        mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
            variables=[], total=0
        )
        resp = setup_test_context(
            app,
            WorkflowVariableCollectionApi,
            "/apps/app_123/workflows/draft/variables?page=1&limit=20",
            "GET",
            mock_account,
            mock_app_model,
        )
        assert resp == {"items": [], "total": 0}

    @patch("controllers.console.app.workflow_draft_variable.WorkflowService")
    def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
        """Listing variables without a draft workflow raises DraftWorkflowNotExist."""
        mock_wf_srv.return_value.is_workflow_exist.return_value = False
        with pytest.raises(DraftWorkflowNotExist):
            setup_test_context(
                app,
                WorkflowVariableCollectionApi,
                "/apps/app_123/workflows/draft/variables",
                "GET",
                mock_account,
                mock_app_model,
            )

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
        """Bulk delete of draft variables answers 204 No Content."""
        resp = setup_test_context(
            app,
            WorkflowVariableCollectionApi,
            "/apps/app_123/workflows/draft/variables",
            "DELETE",
            mock_account,
            mock_app_model,
        )
        assert resp.status_code == 204

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
        """Listing a node's variables returns the serialized (empty) collection."""
        from services.workflow_draft_variable_service import WorkflowDraftVariableList
        mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
        resp = setup_test_context(
            app,
            NodeVariableCollectionApi,
            "/apps/app_123/workflows/draft/nodes/node_123/variables",
            "GET",
            mock_account,
            mock_app_model,
        )
        assert resp == {"items": []}

    def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
        """The reserved node id 'sys' is rejected with InvalidArgumentError."""
        with pytest.raises(InvalidArgumentError):
            setup_test_context(
                app,
                NodeVariableCollectionApi,
                "/apps/app_123/workflows/draft/nodes/sys/variables",
                "GET",
                mock_account,
                mock_app_model,
            )

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
        """Deleting a node's variables answers 204 No Content."""
        resp = setup_test_context(
            app,
            NodeVariableCollectionApi,
            "/apps/app_123/workflows/draft/nodes/node_123/variables",
            "DELETE",
            mock_account,
            mock_app_model,
        )
        assert resp.status_code == 204

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
        """Fetching an existing variable serializes it with its id."""
        mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
        resp = setup_test_context(
            app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
        )
        assert resp["id"] == "var_123"

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
        """A missing variable raises NotFoundError."""
        mock_draft_srv.return_value.get_variable.return_value = None
        with pytest.raises(NotFoundError):
            setup_test_context(
                app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
            )

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
        """PATCH renames the variable via the service and returns the serialized variable."""
        mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
        resp = setup_test_context(
            app,
            VariableApi,
            "/apps/app_123/workflows/draft/variables/var_123",
            "PATCH",
            mock_account,
            mock_app_model,
            payload={"name": "new_name"},
        )
        assert resp["id"] == "var_123"
        mock_draft_srv.return_value.update_variable.assert_called_once()

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
        """DELETE removes the variable via the service and answers 204."""
        mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
        resp = setup_test_context(
            app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
        )
        assert resp.status_code == 204
        mock_draft_srv.return_value.delete_variable.assert_called_once()

    @patch("controllers.console.app.workflow_draft_variable.WorkflowService")
    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
        """Resetting a variable with no replacement value answers 204."""
        mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
        mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
        mock_draft_srv.return_value.reset_variable.return_value = None  # means no content
        resp = setup_test_context(
            app,
            VariableResetApi,
            "/apps/app_123/workflows/draft/variables/var_123/reset",
            "PUT",
            mock_account,
            mock_app_model,
        )
        assert resp.status_code == 204

    @patch("controllers.console.app.workflow_draft_variable.WorkflowService")
    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
        """Conversation variables of a draft workflow serialize to an empty collection."""
        mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
        from services.workflow_draft_variable_service import WorkflowDraftVariableList
        mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
        resp = setup_test_context(
            app,
            ConversationVariableCollectionApi,
            "/apps/app_123/workflows/draft/conversation-variables",
            "GET",
            mock_account,
            mock_app_model,
        )
        assert resp == {"items": []}

    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
        """System variables serialize to an empty collection."""
        from services.workflow_draft_variable_service import WorkflowDraftVariableList
        mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
        resp = setup_test_context(
            app,
            SystemVariableCollectionApi,
            "/apps/app_123/workflows/draft/system-variables",
            "GET",
            mock_account,
            mock_app_model,
        )
        assert resp == {"items": []}

    @patch("controllers.console.app.workflow_draft_variable.WorkflowService")
    def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
        """Environment variables come straight from the draft workflow object."""
        mock_wf = MagicMock()
        mock_wf.environment_variables = []
        mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
        resp = setup_test_context(
            app,
            EnvironmentVariableCollectionApi,
            "/apps/app_123/workflows/draft/environment-variables",
            "GET",
            mock_account,
            mock_app_model,
        )
        assert resp == {"items": []}

View File

@ -0,0 +1,209 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.data_source_bearer_auth import (
ApiKeyAuthDataSource,
ApiKeyAuthDataSourceBinding,
ApiKeyAuthDataSourceBindingDelete,
)
from controllers.console.auth.error import ApiKeyAuthFailedError
class TestApiKeyAuthDataSource:
    """GET tests for the api-key-auth data-source listing endpoint."""

    @pytest.fixture
    def app(self):
        """Provide a minimal Flask app for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        return app

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
    def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
        """A single binding from the service appears as one source entry."""
        from models.account import Account, AccountStatus
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        mock_binding = MagicMock()
        mock_binding.id = "bind_123"
        mock_binding.category = "api_key"
        mock_binding.provider = "custom_provider"
        mock_binding.disabled = False
        mock_binding.created_at.timestamp.return_value = 1620000000
        mock_binding.updated_at.timestamp.return_value = 1620000001
        mock_get_list.return_value = [mock_binding]
        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSource()
                    response = api_instance.get()
                    assert "sources" in response
                    assert len(response["sources"]) == 1
                    assert response["sources"][0]["provider"] == "custom_provider"

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
    def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
        """A None auth list from the service yields an empty sources array."""
        from models.account import Account, AccountStatus
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        mock_get_list.return_value = None
        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSource()
                    response = api_instance.get()
                    assert "sources" in response
                    assert len(response["sources"]) == 0
class TestApiKeyAuthDataSourceBinding:
    """POST tests for creating an api-key-auth data-source binding."""

    @pytest.fixture
    def app(self):
        """Provide a minimal Flask app for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        return app

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
    def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
        """A valid payload validates, creates the binding, and answers success/200."""
        from models.account import Account, AccountStatus
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context(
                "/console/api/api-key-auth/data-source/binding",
                method="POST",
                json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
            ):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSourceBinding()
                    response = api_instance.post()
                    assert response[0]["result"] == "success"
                    assert response[1] == 200
                    mock_validate.assert_called_once()
                    mock_create.assert_called_once()

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
    def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
        """A ValueError from the service is surfaced as ApiKeyAuthFailedError."""
        from models.account import Account, AccountStatus
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        mock_create.side_effect = ValueError("Invalid structure")
        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context(
                "/console/api/api-key-auth/data-source/binding",
                method="POST",
                json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
            ):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSourceBinding()
                    with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
                        api_instance.post()
class TestApiKeyAuthDataSourceBindingDelete:
    """Tests for deleting an existing API-key auth binding (DELETE endpoint)."""
    @pytest.fixture
    def app(self):
        """Minimal Flask app: testing mode enabled, CSRF protection disabled."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        return app
    # Decorators apply bottom-up: mock_delete, mock_db, mock_csrf.
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
    def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
        """Deleting a binding delegates to the service with (tenant_id, binding_id) and returns 204."""
        from models.account import Account, AccountStatus
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
                # flask-login unwraps its proxy via _get_current_object().
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSourceBindingDelete()
                    response = api_instance.delete("binding_123")
                    assert response[0]["result"] == "success"
                    assert response[1] == 204
                    mock_delete.assert_called_once_with("tenant_123", "binding_123")

View File

@ -0,0 +1,192 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.local import LocalProxy
from controllers.console.auth.data_source_oauth import (
OAuthDataSource,
OAuthDataSourceBinding,
OAuthDataSourceCallback,
OAuthDataSourceSync,
)
class TestOAuthDataSource:
    """Tests for generating the OAuth authorization URL for a data-source provider."""
    @pytest.fixture
    def app(self):
        """Flask app in testing mode for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    # Decorators apply bottom-up; NOTION_INTEGRATION_TYPE is patched to None so the
    # controller takes the public-OAuth code path rather than the internal-integration one.
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    @patch("flask_login.current_user")
    @patch("libs.login.current_user")
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
    def test_get_oauth_url_successful(
        self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
    ):
        """A known provider returns its authorization URL with HTTP 200."""
        mock_oauth_provider = MagicMock()
        mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
        mock_get_providers.return_value = {"notion": mock_oauth_provider}
        from models.account import Account, AccountStatus
        # Active admin/owner account so the console auth decorators pass.
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        mock_libs_user.return_value = mock_account
        mock_flask_user.return_value = mock_account
        # also patch current_account_with_tenant
        with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
            with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
                proxy_mock = LocalProxy(lambda: mock_account)
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = OAuthDataSource()
                    response = api_instance.get("notion")
                    assert response[0]["data"] == "http://oauth.provider/auth"
                    assert response[1] == 200
                    mock_oauth_provider.get_authorization_url.assert_called_once()
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    @patch("flask_login.current_user")
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
        """An unknown provider name yields a 400 'Invalid provider' error."""
        mock_get_providers.return_value = {"notion": MagicMock()}
        from models.account import Account, AccountStatus
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
            with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
                proxy_mock = LocalProxy(lambda: mock_account)
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = OAuthDataSource()
                    response = api_instance.get("unknown_provider")
                    assert response[0]["error"] == "Invalid provider"
                    assert response[1] == 400
class TestOAuthDataSourceCallback:
    """Unit tests for the OAuth data-source callback redirect endpoint."""
    @pytest.fixture
    def app(self):
        """Provide a Flask application configured for testing."""
        flask_app = Flask(__name__)
        flask_app.config["TESTING"] = True
        return flask_app
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_oauth_callback_successful(self, mock_get_providers, app):
        """A callback carrying a code redirects (302) and forwards the code."""
        notion_provider = MagicMock()
        mock_get_providers.return_value = {"notion": notion_provider}
        callback_url = "/console/api/oauth/data-source/notion/callback?code=mock_code"
        with app.test_request_context(callback_url, method="GET"):
            resource = OAuthDataSourceCallback()
            redirect_response = resource.get("notion")
        assert redirect_response.status_code == 302
        assert "code=mock_code" in redirect_response.location
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_oauth_callback_missing_code(self, mock_get_providers, app):
        """Without a code parameter the redirect carries an access-denied error."""
        notion_provider = MagicMock()
        mock_get_providers.return_value = {"notion": notion_provider}
        callback_url = "/console/api/oauth/data-source/notion/callback"
        with app.test_request_context(callback_url, method="GET"):
            resource = OAuthDataSourceCallback()
            redirect_response = resource.get("notion")
        assert redirect_response.status_code == 302
        assert "error=Access denied" in redirect_response.location
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
        """An unregistered provider name produces a 400 'Invalid provider' response."""
        mock_get_providers.return_value = {"notion": MagicMock()}
        callback_url = "/console/api/oauth/data-source/invalid/callback?code=mock_code"
        with app.test_request_context(callback_url, method="GET"):
            resource = OAuthDataSourceCallback()
            body, status = resource.get("invalid")
        assert body["error"] == "Invalid provider"
        assert status == 400
class TestOAuthDataSourceBinding:
    """Unit tests for exchanging an OAuth code for a data-source binding."""
    @pytest.fixture
    def app(self):
        """Provide a Flask application configured for testing."""
        flask_app = Flask(__name__)
        flask_app.config["TESTING"] = True
        return flask_app
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_get_binding_successful(self, mock_get_providers, app):
        """A valid code is exchanged through the provider and reports success."""
        notion_provider = MagicMock()
        notion_provider.get_access_token.return_value = None
        mock_get_providers.return_value = {"notion": notion_provider}
        binding_url = "/console/api/oauth/data-source/notion/binding?code=auth_code_123"
        with app.test_request_context(binding_url, method="GET"):
            resource = OAuthDataSourceBinding()
            body, status = resource.get("notion")
        assert body["result"] == "success"
        assert status == 200
        notion_provider.get_access_token.assert_called_once_with("auth_code_123")
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_get_binding_missing_code(self, mock_get_providers, app):
        """An empty code parameter is rejected with a 400 'Invalid code' error."""
        mock_get_providers.return_value = {"notion": MagicMock()}
        binding_url = "/console/api/oauth/data-source/notion/binding?code="
        with app.test_request_context(binding_url, method="GET"):
            resource = OAuthDataSourceBinding()
            body, status = resource.get("notion")
        assert body["error"] == "Invalid code"
        assert status == 400
class TestOAuthDataSourceSync:
    """Tests for triggering a sync of an OAuth-bound data source."""
    @pytest.fixture
    def app(self):
        """Flask app in testing mode for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    # Decorators apply bottom-up: mock_db, mock_csrf, mock_get_providers.
    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
        """Syncing a binding delegates to the provider's sync_data_source and returns 200."""
        mock_provider = MagicMock()
        mock_provider.sync_data_source.return_value = None
        mock_get_providers.return_value = {"notion": mock_provider}
        from models.account import Account, AccountStatus
        # Active admin/owner account so the console auth decorators pass.
        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
            with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
                proxy_mock = LocalProxy(lambda: mock_account)
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = OAuthDataSourceSync()
                    # The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
                    response = api_instance.get("notion", "binding_123")
                    assert response[0]["result"] == "success"
                    assert response[1] == 200
                    mock_provider.sync_data_source.assert_called_once_with("binding_123")

View File

@ -0,0 +1,417 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.auth.oauth_server import (
OAuthServerAppApi,
OAuthServerUserAccountApi,
OAuthServerUserAuthorizeApi,
OAuthServerUserTokenApi,
)
class TestOAuthServerAppApi:
    """Tests for looking up an OAuth provider app's metadata by client_id."""
    @pytest.fixture
    def app(self):
        """Flask app in testing mode for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    @pytest.fixture
    def mock_oauth_provider_app(self):
        """A registered OAuth provider app with one allowed redirect URI."""
        from models.model import OAuthProviderApp
        oauth_app = MagicMock(spec=OAuthProviderApp)
        oauth_app.client_id = "test_client_id"
        oauth_app.redirect_uris = ["http://localhost/callback"]
        oauth_app.app_icon = "icon_url"
        oauth_app.app_label = "Test App"
        oauth_app.scope = "read,write"
        return oauth_app
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A matching client_id/redirect_uri pair returns the app's icon, label and scope."""
        # Satisfies the setup_required decorator's DB lookup.
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider",
            method="POST",
            json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
        ):
            api_instance = OAuthServerAppApi()
            response = api_instance.post()
            assert response["app_icon"] == "icon_url"
            assert response["app_label"] == "Test App"
            assert response["scope"] == "read,write"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A redirect_uri not in the app's allow-list raises BadRequest."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider",
            method="POST",
            json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
        ):
            api_instance = OAuthServerAppApi()
            with pytest.raises(BadRequest, match="redirect_uri is invalid"):
                api_instance.post()
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_client_id(self, mock_get_app, mock_db, app):
        """An unknown client_id raises NotFound."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = None
        with app.test_request_context(
            "/oauth/provider",
            method="POST",
            json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
        ):
            api_instance = OAuthServerAppApi()
            with pytest.raises(NotFound, match="client_id is invalid"):
                api_instance.post()
class TestOAuthServerUserAuthorizeApi:
    """Tests for issuing an OAuth authorization code to a logged-in user."""
    @pytest.fixture
    def app(self):
        """Flask app in testing mode for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    @pytest.fixture
    def mock_oauth_provider_app(self):
        """A minimal OAuth provider app carrying only a client_id."""
        oauth_app = MagicMock()
        oauth_app.client_id = "test_client_id"
        return oauth_app
    # Decorators apply bottom-up, so the mock arguments arrive in reverse order:
    # mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db.
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.current_account_with_tenant")
    @patch("controllers.console.wraps.current_account_with_tenant")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
    @patch("libs.login.check_csrf_token")
    def test_successful_authorize(
        self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
    ):
        """An active user authorizing a known client receives a signed code."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_account = MagicMock()
        mock_account.id = "user_123"
        from models.account import AccountStatus
        mock_account.status = AccountStatus.ACTIVE
        # Tenant resolution is patched in both the controller and the wraps module.
        mock_current.return_value = (mock_account, MagicMock())
        mock_wrap_current.return_value = (mock_account, MagicMock())
        mock_sign.return_value = "auth_code_123"
        with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
            with patch("libs.login.current_user", mock_account):
                api_instance = OAuthServerUserAuthorizeApi()
                response = api_instance.post()
                assert response["code"] == "auth_code_123"
                # The code must be signed for this client and this user.
                mock_sign.assert_called_once_with("test_client_id", "user_123")
class TestOAuthServerUserTokenApi:
    """Tests for the OAuth token endpoint: authorization_code and refresh_token grants."""
    @pytest.fixture
    def app(self):
        """Flask app in testing mode for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    @pytest.fixture
    def mock_oauth_provider_app(self):
        """A registered OAuth provider app with secret and allowed redirect URI."""
        from models.model import OAuthProviderApp
        oauth_app = MagicMock(spec=OAuthProviderApp)
        oauth_app.client_id = "test_client_id"
        oauth_app.client_secret = "test_secret"
        oauth_app.redirect_uris = ["http://localhost/callback"]
        return oauth_app
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
    def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A complete authorization_code request yields Bearer access + refresh tokens."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_sign.return_value = ("access_123", "refresh_123")
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "code": "auth_code",
                "client_secret": "test_secret",
                "redirect_uri": "http://localhost/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            response = api_instance.post()
            assert response["access_token"] == "access_123"
            assert response["refresh_token"] == "refresh_123"
            assert response["token_type"] == "Bearer"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """Omitting the authorization code raises BadRequest."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "client_secret": "test_secret",
                "redirect_uri": "http://localhost/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="code is required"):
                api_instance.post()
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A client_secret mismatch raises BadRequest."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "code": "auth_code",
                "client_secret": "invalid_secret",
                "redirect_uri": "http://localhost/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="client_secret is invalid"):
                api_instance.post()
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A redirect_uri outside the app's allow-list raises BadRequest."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "code": "auth_code",
                "client_secret": "test_secret",
                "redirect_uri": "http://invalid/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="redirect_uri is invalid"):
                api_instance.post()
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
    def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A refresh_token grant returns freshly signed access and refresh tokens."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_sign.return_value = ("new_access", "new_refresh")
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
        ):
            api_instance = OAuthServerUserTokenApi()
            response = api_instance.post()
            assert response["access_token"] == "new_access"
            assert response["refresh_token"] == "new_refresh"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """Omitting refresh_token in a refresh_token grant raises BadRequest."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "refresh_token",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="refresh_token is required"):
                api_instance.post()
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """An unsupported grant_type raises BadRequest."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "invalid_grant",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="invalid grant_type"):
                api_instance.post()
class TestOAuthServerUserAccountApi:
    """Tests for fetching account info with a Bearer OAuth access token."""
    @pytest.fixture
    def app(self):
        """Flask app in testing mode for building request contexts."""
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    @pytest.fixture
    def mock_oauth_provider_app(self):
        """A minimal registered OAuth provider app carrying only a client_id."""
        from models.model import OAuthProviderApp
        oauth_app = MagicMock(spec=OAuthProviderApp)
        oauth_app.client_id = "test_client_id"
        return oauth_app
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
    def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A valid Bearer token returns the account's profile fields."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_account = MagicMock()
        mock_account.name = "Test User"
        mock_account.email = "test@example.com"
        mock_account.avatar = "avatar_url"
        mock_account.interface_language = "en-US"
        mock_account.timezone = "UTC"
        mock_validate.return_value = mock_account
        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer valid_access_token"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()
            assert response["name"] == "Test User"
            assert response["email"] == "test@example.com"
            assert response["avatar"] == "avatar_url"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """Requests without an Authorization header are rejected with 401."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()
            assert response.status_code == 401
            assert response.json["error"] == "Authorization header is required"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A header without the '<type> <token>' shape is rejected with 401."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "InvalidFormat"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()
            assert response.status_code == 401
            assert response.json["error"] == "Invalid Authorization header format"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """Non-Bearer schemes (e.g. Basic) are rejected with 401."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Basic something"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()
            assert response.status_code == 401
            assert response.json["error"] == "token_type is invalid"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """'Bearer ' with no token is treated as a malformed header (401)."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer "},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()
            assert response.status_code == 401
            assert response.json["error"] == "Invalid Authorization header format"
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
    def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
        """A token that fails service validation is rejected with 401."""
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        # Service returns None for tokens it cannot validate.
        mock_validate.return_value = None
        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer invalid_token"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()
            assert response.status_code == 401
            assert response.json["error"] == "access_token or client_id is invalid"

View File

@ -24,13 +24,8 @@ class TestBannerApi:
banner.status = BannerStatus.ENABLED
banner.created_at = datetime(2024, 1, 1)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = [banner]
session = MagicMock()
session.query.return_value = query
session.scalars.return_value.all.return_value = [banner]
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
result = method(api)
@ -58,16 +53,14 @@ class TestBannerApi:
banner.status = BannerStatus.ENABLED
banner.created_at = None
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.side_effect = [
scalars_result = MagicMock()
scalars_result.all.side_effect = [
[],
[banner],
]
session = MagicMock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
result = method(api)
@ -87,13 +80,8 @@ class TestBannerApi:
api = banner_module.BannerApi()
method = unwrap(api.get)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = []
session = MagicMock()
session.query.return_value = query
session.scalars.return_value.all.return_value = []
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
result = method(api)

View File

@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi:
app_entity.tenant_id = "t2"
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
None,
]
# scalar() is called for recommended_app and installed_app lookups
session.scalar.side_effect = [recommended, None]
# get() is called for app PK lookup
session.get.return_value = app_entity
with (
app.test_request_context("/", json={"app_id": "a1"}),
@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi:
method = unwrap(api.post)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with (
app.test_request_context("/", json={"app_id": "a1"}),
@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi:
app_entity = MagicMock(is_public=False)
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
]
# scalar() returns recommended_app
session.scalar.return_value = recommended
# get() returns the app entity
session.get.return_value = app_entity
with (
app.test_request_context("/", json={"app_id": "a1"}),

View File

@ -958,8 +958,8 @@ class TestTrialSitApi:
app_model = MagicMock()
app_model.id = "a1"
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
mock_scalar.return_value = None
with pytest.raises(Forbidden):
method(api, app_model)
@ -973,8 +973,8 @@ class TestTrialSitApi:
app_model.tenant = MagicMock()
app_model.tenant.status = TenantStatus.ARCHIVE
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = site
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
mock_scalar.return_value = site
with pytest.raises(Forbidden):
method(api, app_model)
@ -990,10 +990,10 @@ class TestTrialSitApi:
with (
app.test_request_context("/"),
patch.object(module.db.session, "query") as mock_query,
patch.object(module.db.session, "scalar") as mock_scalar,
patch.object(module.SiteResponse, "model_validate") as mock_validate,
):
mock_query.return_value.where.return_value.first.return_value = site
mock_scalar.return_value = site
mock_validate_result = MagicMock()
mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"}
mock_validate.return_value = mock_validate_result

View File

@ -34,9 +34,9 @@ def test_installed_app_required_not_found():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = None
scalar_mock.return_value = None
with pytest.raises(NotFound):
view("app-id")
@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
patch("controllers.console.explore.wraps.db.session.delete"),
patch("controllers.console.explore.wraps.db.session.commit"),
):
q.return_value.where.return_value.first.return_value = installed_app
scalar_mock.return_value = installed_app
with pytest.raises(NotFound):
view("app-id")
@ -76,9 +76,9 @@ def test_installed_app_required_success():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = installed_app
scalar_mock.return_value = installed_app
result = view("app-id")
assert result == installed_app
@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = None
scalar_mock.return_value = None
with pytest.raises(TrialAppNotAllowed):
view("app-id")
@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.side_effect = [
scalar_mock.side_effect = [
trial_app,
record,
]
@ -194,9 +194,9 @@ def test_trial_app_required_success():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.side_effect = [
scalar_mock.side_effect = [
trial_app,
record,
]

View File

@ -114,7 +114,7 @@ class TestBaseApiKeyResource:
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = None
db_mock.session.scalar.return_value = None
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Exception) as exc_info:
@ -125,7 +125,7 @@ class TestBaseApiKeyResource:
def test_delete_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
db_mock.session.scalar.return_value = MagicMock()
with (
patch("controllers.console.apikey._get_resource"),

View File

@ -328,7 +328,7 @@ class TestSystemSetup:
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_db.session.scalar.return_value = None # No setup
mock_environ_get.return_value = "some_password"
@setup_required
@ -345,7 +345,7 @@ class TestSystemSetup:
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_db.session.scalar.return_value = None # No setup
mock_environ_get.return_value = None # No INIT_PASSWORD
@setup_required

View File

@ -55,9 +55,9 @@ class TestAccountInitApi:
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
scalar_mock.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"

View File

@ -207,10 +207,10 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 200
@ -226,9 +226,9 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
):
q.return_value.where.return_value.first.return_value = None
get_mock.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
@ -244,13 +244,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 400
@ -266,13 +266,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 403
@ -288,13 +288,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 404

View File

@ -36,7 +36,115 @@ def unwrap(func):
class TestTenantListApi:
def test_get_success(self, app):
def test_get_success_saas_path(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={
"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0},
"t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0},
},
) as get_plan_bulk_mock,
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
"""Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.
billing.enabled is mocked False to prove the endpoint does not gate on it for this path
(SaaS contract treats enabled as on; display follows subscription.plan).
"""
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
features_t2 = MagicMock()
features_t2.billing.enabled = False
features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}},
) as get_plan_bulk_mock,
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features_t2,
) as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_called_once_with("t2")
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
"""Test fallback to FeatureService when bulk billing returns empty result.
BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
so we simulate the real failure mode by returning empty dict for non-empty input.
"""
api = TenantListApi()
method = unwrap(api.get)
@ -54,27 +162,41 @@ class TestTenantListApi:
)
features = MagicMock()
features.billing.enabled = True
features.billing.subscription.plan = CloudPlan.SANDBOX
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.TEAM
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={}, # Simulates real failure: empty result for non-empty input
) as get_plan_bulk_mock,
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
) as get_features_mock,
patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock,
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.TEAM
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
assert get_features_mock.call_count == 2
logger_warning_mock.assert_called_once()
def test_get_billing_disabled(self, app):
def test_get_billing_disabled_community_path(self, app):
api = TenantListApi()
method = unwrap(api.get)
@ -87,6 +209,7 @@ class TestTenantListApi:
features = MagicMock()
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.SANDBOX
with (
app.test_request_context("/workspaces"),
@ -98,15 +221,83 @@ class TestTenantListApi:
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
),
) as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
def test_get_enterprise_only_skips_feature_service(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
assert result["workspaces"][0]["current"] is False
assert result["workspaces"][1]["current"] is True
get_features_mock.assert_not_called()
def test_get_enterprise_only_with_empty_tenants(self, app):
api = TenantListApi()
method = unwrap(api.get)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"] == []
get_features_mock.assert_not_called()
class TestWorkspaceListApi:
@ -258,12 +449,12 @@ class TestSwitchWorkspaceApi:
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
),
):
query_mock.return_value.get.return_value = tenant
get_mock.return_value = tenant
result = method(api)
assert result["result"] == "success"
@ -297,9 +488,9 @@ class TestSwitchWorkspaceApi:
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
):
query_mock.return_value.get.return_value = None
get_mock.return_value = None
with pytest.raises(ValueError):
method(api)

View File

@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.enums import ConversationFromSource
from models.model import AppMode, Conversation, Message
@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation():
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id="user-id",
from_account_id=None,
)

View File

@ -0,0 +1,181 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
from models.api_based_extension import APIBasedExtension
class TestApiModeration:
@pytest.fixture
def api_config(self):
return {
"inputs_config": {
"enabled": True,
},
"outputs_config": {
"enabled": True,
},
"api_based_extension_id": "test-extension-id",
}
@pytest.fixture
def api_moderation(self, api_config):
return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)
def test_moderation_input_params(self):
params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
assert params.app_id == "app-1"
assert params.inputs == {"key": "val"}
assert params.query == "test query"
# Test defaults
params_default = ModerationInputParams()
assert params_default.app_id == ""
assert params_default.inputs == {}
assert params_default.query == ""
def test_moderation_output_params(self):
params = ModerationOutputParams(app_id="app-1", text="test text")
assert params.app_id == "app-1"
assert params.text == "test text"
with pytest.raises(ValidationError):
ModerationOutputParams()
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_validate_config_success(self, mock_get_extension, api_config):
mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
ApiModeration.validate_config("test-tenant-id", api_config)
mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")
def test_validate_config_missing_extension_id(self):
config = {
"inputs_config": {"enabled": True},
"outputs_config": {"enabled": True},
}
with pytest.raises(ValueError, match="api_based_extension_id is required"):
ApiModeration.validate_config("test-tenant-id", config)
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
mock_get_extension.return_value = None
with pytest.raises(ValueError, match="API-based Extension not found"):
ApiModeration.validate_config("test-tenant-id", api_config)
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
assert isinstance(result, ModerationInputsResult)
assert result.flagged is True
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == "Blocked by API"
mock_get_config.assert_called_once_with(
APIBasedExtensionPoint.APP_MODERATION_INPUT,
{"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
)
def test_moderation_for_inputs_disabled(self):
config = {
"inputs_config": {"enabled": False},
"outputs_config": {"enabled": True},
"api_based_extension_id": "ext-id",
}
moderation = ApiModeration("app-id", "tenant-id", config)
result = moderation.moderation_for_inputs(inputs={}, query="")
assert result.flagged is False
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == ""
def test_moderation_for_inputs_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation.moderation_for_inputs({}, "")
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
result = api_moderation.moderation_for_outputs(text="hello world")
assert isinstance(result, ModerationOutputsResult)
assert result.flagged is False
mock_get_config.assert_called_once_with(
APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
)
def test_moderation_for_outputs_disabled(self):
config = {
"inputs_config": {"enabled": True},
"outputs_config": {"enabled": False},
"api_based_extension_id": "ext-id",
}
moderation = ApiModeration("app-id", "tenant-id", config)
result = moderation.moderation_for_outputs(text="test")
assert result.flagged is False
assert result.action == ModerationAction.DIRECT_OUTPUT
def test_moderation_for_outputs_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation.moderation_for_outputs("test")
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
@patch("core.moderation.api.api.decrypt_token")
@patch("core.moderation.api.api.APIBasedExtensionRequestor")
def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
mock_ext = MagicMock(spec=APIBasedExtension)
mock_ext.api_endpoint = "http://api.test"
mock_ext.api_key = "encrypted-key"
mock_get_ext.return_value = mock_ext
mock_decrypt.return_value = "decrypted-key"
mock_requestor = MagicMock()
mock_requestor.request.return_value = {"flagged": True}
mock_requestor_cls.return_value = mock_requestor
params = {"some": "params"}
result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
assert result == {"flagged": True}
mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
def test_get_config_by_requestor_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
mock_get_ext.return_value = None
with pytest.raises(ValueError, match="API-based Extension not found"):
api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
@patch("core.moderation.api.api.db.session.scalar")
def test_get_api_based_extension(self, mock_scalar):
mock_ext = MagicMock(spec=APIBasedExtension)
mock_scalar.return_value = mock_ext
result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
assert result == mock_ext
mock_scalar.assert_called_once()
# Verify the call has the correct filters
args, kwargs = mock_scalar.call_args
stmt = args[0]
# We can't easily inspect the statement without complex sqlalchemy tricks,
# but calling it is usually enough for unit tests if we mock the result.

View File

@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
from core.moderation.input_moderation import InputModeration
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager
class TestInputModeration:
@pytest.fixture
def app_config(self):
config = MagicMock(spec=AppConfig)
config.sensitive_word_avoidance = None
return config
@pytest.fixture
def input_moderation(self):
return InputModeration()
def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is False
assert final_inputs == inputs
assert final_query == query
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {"keywords": ["bad"]}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is False
assert final_inputs == inputs
assert final_query == query
mock_factory_cls.assert_called_once_with(
name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
)
mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)
@patch("core.moderation.input_moderation.ModerationFactory")
@patch("core.moderation.input_moderation.TraceTask")
def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
trace_manager = MagicMock(spec=TraceQueueManager)
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_inputs.return_value = mock_result
input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
trace_manager=trace_manager,
)
trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
mock_trace_task.assert_called_once()
call_kwargs = mock_trace_task.call_args.kwargs
call_args = mock_trace_task.call_args.args
assert call_args[0] == TraceTaskName.MODERATION_TRACE
assert call_kwargs["message_id"] == message_id
assert call_kwargs["moderation_result"] == mock_result
assert call_kwargs["inputs"] == inputs
assert "timer" in call_kwargs
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
)
mock_factory.moderation_for_inputs.return_value = mock_result
with pytest.raises(ModerationError) as excinfo:
input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
)
assert str(excinfo.value) == "Blocked content"
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(
flagged=True,
action=ModerationAction.OVERRIDDEN,
inputs={"input_key": "overridden_value"},
query="overridden query",
)
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is True
assert final_inputs == {"input_key": "overridden_value"}
assert final_query == "overridden query"
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = MagicMock()
mock_result.flagged = True
mock_result.action = "NONE" # Some other action
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
)
assert flagged is True
assert final_inputs == inputs
assert final_query == query

View File

@ -0,0 +1,234 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.output_moderation import ModerationRule, OutputModeration
class TestOutputModeration:
@pytest.fixture
def mock_queue_manager(self):
return MagicMock(spec=AppQueueManager)
@pytest.fixture
def moderation_rule(self):
return ModerationRule(type="keywords", config={"keywords": "badword"})
@pytest.fixture
def output_moderation(self, mock_queue_manager, moderation_rule):
return OutputModeration(
tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
)
def test_should_direct_output(self, output_moderation):
assert output_moderation.should_direct_output() is False
output_moderation.final_output = "blocked"
assert output_moderation.should_direct_output() is True
def test_get_final_output(self, output_moderation):
assert output_moderation.get_final_output() == ""
output_moderation.final_output = "blocked"
assert output_moderation.get_final_output() == "blocked"
def test_append_new_token(self, output_moderation):
with patch.object(OutputModeration, "start_thread") as mock_start:
output_moderation.append_new_token("hello")
assert output_moderation.buffer == "hello"
mock_start.assert_called_once()
output_moderation.thread = MagicMock()
output_moderation.append_new_token(" world")
assert output_moderation.buffer == "hello world"
assert mock_start.call_count == 1
def test_moderation_completion_no_flag(self, output_moderation):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
output, flagged = output_moderation.moderation_completion("safe content")
assert output == "safe content"
assert flagged is False
assert output_moderation.is_final_chunk is True
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "preset"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert isinstance(args[0], QueueMessageReplaceEvent)
assert args[0].text == "preset"
assert args[1] == PublishFrom.TASK_PIPELINE
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
    """Flagged OVERRIDDEN substitutes the moderated (masked) text and publishes it."""
    with patch.object(OutputModeration, "moderation") as mock_moderation:
        mock_moderation.return_value = ModerationOutputsResult(
            flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
        )
        output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
        assert output == "masked content"
        assert flagged is True
        mock_queue_manager.publish.assert_called_once()
        args, _ = mock_queue_manager.publish.call_args
        assert args[0].text == "masked content"
def test_start_thread(self, output_moderation):
    """start_thread creates and starts a worker thread bound to the Flask app."""
    mock_app = MagicMock(spec=Flask)
    with patch("core.moderation.output_moderation.current_app") as mock_current_app:
        # start_thread unwraps the app proxy via _get_current_object.
        mock_current_app._get_current_object.return_value = mock_app
        with patch("threading.Thread") as mock_thread_class:
            mock_thread_instance = MagicMock()
            mock_thread_class.return_value = mock_thread_instance
            thread = output_moderation.start_thread()
            assert thread == mock_thread_instance
            mock_thread_class.assert_called_once()
            mock_thread_instance.start.assert_called_once()
def test_stop_thread(self, output_moderation):
    """stop_thread clears thread_running only while the thread is alive."""
    mock_thread = MagicMock()
    mock_thread.is_alive.return_value = True
    output_moderation.thread = mock_thread
    output_moderation.stop_thread()
    assert output_moderation.thread_running is False
    # A dead thread leaves the running flag untouched.
    output_moderation.thread_running = True
    mock_thread.is_alive.return_value = False
    output_moderation.stop_thread()
    assert output_moderation.thread_running is True
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_success(self, mock_factory_class, output_moderation):
    """moderation delegates to ModerationFactory with the configured rule."""
    mock_factory = mock_factory_class.return_value
    mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
    mock_factory.moderation_for_outputs.return_value = mock_result
    result = output_moderation.moderation("tenant", "app", "buffer")
    assert result == mock_result
    # Factory receives the rule type/config carried by the fixture.
    mock_factory_class.assert_called_once_with(
        name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
    )
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_exception(self, mock_factory_class, output_moderation):
    """Factory failures are swallowed and surfaced as a None result."""
    mock_factory_class.side_effect = Exception("error")
    result = output_moderation.moderation("tenant", "app", "buffer")
    assert result is None
def test_worker_loop_and_exit(self, output_moderation):
    """worker returns immediately when thread_running is already False.

    The ``mock_queue_manager`` fixture previously requested here was unused
    and has been dropped.
    """
    mock_app = MagicMock(spec=Flask)
    # Test exit on thread_running=False
    output_moderation.thread_running = False
    output_moderation.worker(mock_app, 10)
    # Should exit immediately without raising.
def test_worker_no_flag(self, output_moderation):
    """worker calls moderation on the buffer and loops until stopped."""
    mock_app = MagicMock(spec=Flask)
    with patch.object(OutputModeration, "moderation") as mock_moderation:
        mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        output_moderation.buffer = "safe"
        output_moderation.is_final_chunk = True

        # To avoid infinite loop, we'll set thread_running to False after one iteration
        def side_effect(*args, **kwargs):
            output_moderation.thread_running = False
            return mock_moderation.return_value

        mock_moderation.side_effect = side_effect
        output_moderation.worker(mock_app, 10)
        assert mock_moderation.called
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
    """A flagged DIRECT_OUTPUT result stores the preset response and stops the loop."""
    mock_app = MagicMock(spec=Flask)
    with patch.object(OutputModeration, "moderation") as mock_moderation:
        mock_moderation.return_value = ModerationOutputsResult(
            flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
        )
        output_moderation.buffer = "badword"
        output_moderation.is_final_chunk = True
        output_moderation.worker(mock_app, 10)
        assert output_moderation.final_output == "preset"
        mock_queue_manager.publish.assert_called_once()
        # It breaks on DIRECT_OUTPUT
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
    """A flagged OVERRIDDEN result publishes the masked text as a replace event."""
    mock_app = MagicMock(spec=Flask)
    with patch.object(OutputModeration, "moderation") as mock_moderation:
        # Use side_effect to change thread_running on second call
        def side_effect(*args, **kwargs):
            if mock_moderation.call_count > 1:
                output_moderation.thread_running = False
                return None
            return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")

        mock_moderation.side_effect = side_effect
        output_moderation.buffer = "badword"
        output_moderation.is_final_chunk = True
        output_moderation.worker(mock_app, 10)
        mock_queue_manager.publish.assert_called_once()
        args, _ = mock_queue_manager.publish.call_args
        assert args[0].text == "masked"
def test_worker_chunk_too_small(self, output_moderation):
    """worker sleeps (rather than moderating) while the buffer is below buffer_size."""
    mock_app = MagicMock(spec=Flask)
    with patch("time.sleep") as mock_sleep:
        # chunk_length < buffer_size and not is_final_chunk
        output_moderation.buffer = "123"  # length 3
        output_moderation.is_final_chunk = False

        def sleep_side_effect(seconds):
            # Stop the loop after the first sleep to keep the test bounded.
            output_moderation.thread_running = False

        mock_sleep.side_effect = sleep_side_effect
        output_moderation.worker(mock_app, 10)  # buffer_size 10
        mock_sleep.assert_called_once_with(1)
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
    """A None moderation result (failure / no rule) publishes nothing."""
    mock_app = MagicMock(spec=Flask)
    with patch.object(OutputModeration, "moderation") as mock_moderation:
        # Return None (exception or no rule)
        mock_moderation.return_value = None

        def side_effect(*args, **kwargs):
            output_moderation.thread_running = False

        mock_moderation.side_effect = side_effect
        output_moderation.buffer = "something"
        output_moderation.is_final_chunk = True
        output_moderation.worker(mock_app, 10)
        mock_queue_manager.publish.assert_not_called()

View File

@ -0,0 +1,160 @@
from unittest.mock import patch
import httpx
import pytest
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
TidbOnQdrantConfig,
TidbOnQdrantVector,
)
class TestTidbOnQdrantVectorDeleteByIds:
    """Unit tests for TidbOnQdrantVector.delete_by_ids method.

    The Qdrant client is patched out, so each test inspects the arguments
    passed to ``_client.delete`` rather than real server behavior.
    """

    @pytest.fixture
    def vector_instance(self):
        """Create a TidbOnQdrantVector instance for testing."""
        config = TidbOnQdrantConfig(
            endpoint="http://localhost:6333",
            api_key="test_api_key",
        )
        # Patch the client class so construction performs no network I/O.
        with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
            vector = TidbOnQdrantVector(
                collection_name="test_collection",
                group_id="test_group",
                config=config,
            )
        return vector

    def test_delete_by_ids_with_multiple_ids(self, vector_instance):
        """Test batch deletion with multiple document IDs."""
        ids = ["doc1", "doc2", "doc3"]
        vector_instance.delete_by_ids(ids)
        # Verify that delete was called once with MatchAny filter
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args
        # Check collection name
        assert call_args[1]["collection_name"] == "test_collection"
        # Verify filter uses MatchAny with all IDs
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        assert len(filter_obj.must) == 1
        field_condition = filter_obj.must[0]
        assert field_condition.key == "metadata.doc_id"
        assert isinstance(field_condition.match, rest.MatchAny)
        assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}

    def test_delete_by_ids_with_single_id(self, vector_instance):
        """Test deletion with a single document ID."""
        ids = ["doc1"]
        vector_instance.delete_by_ids(ids)
        # Verify that delete was called once
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args
        # Verify filter uses MatchAny with single ID
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ["doc1"]

    def test_delete_by_ids_with_empty_list(self, vector_instance):
        """Test deletion with empty ID list returns early without API call."""
        vector_instance.delete_by_ids([])
        # Verify that delete was NOT called
        vector_instance._client.delete.assert_not_called()

    def test_delete_by_ids_with_404_error(self, vector_instance):
        """Test that 404 errors (collection not found) are handled gracefully."""
        ids = ["doc1", "doc2"]
        # Mock a 404 error
        error = UnexpectedResponse(
            status_code=404,
            reason_phrase="Not Found",
            content=b"Collection not found",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error
        # Should not raise an exception
        vector_instance.delete_by_ids(ids)
        # Verify delete was called
        vector_instance._client.delete.assert_called_once()

    def test_delete_by_ids_with_unexpected_error(self, vector_instance):
        """Test that non-404 errors are re-raised."""
        ids = ["doc1", "doc2"]
        # Mock a 500 error
        error = UnexpectedResponse(
            status_code=500,
            reason_phrase="Internal Server Error",
            content=b"Server error",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error
        # Should re-raise the exception
        with pytest.raises(UnexpectedResponse) as exc_info:
            vector_instance.delete_by_ids(ids)
        assert exc_info.value.status_code == 500

    def test_delete_by_ids_with_large_batch(self, vector_instance):
        """Test deletion with a large batch of IDs."""
        # Create 1000 IDs
        ids = [f"doc_{i}" for i in range(1000)]
        vector_instance.delete_by_ids(ids)
        # Verify single delete call with all IDs
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]
        # Verify all 1000 IDs are in the batch
        assert len(field_condition.match.any) == 1000
        assert "doc_0" in field_condition.match.any
        assert "doc_999" in field_condition.match.any

    def test_delete_by_ids_filter_structure(self, vector_instance):
        """Test that the filter structure is correctly constructed."""
        ids = ["doc1", "doc2"]
        vector_instance.delete_by_ids(ids)
        call_args = vector_instance._client.delete.call_args
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        # Verify Filter structure
        assert isinstance(filter_obj, rest.Filter)
        assert filter_obj.must is not None
        assert len(filter_obj.must) == 1
        # Verify FieldCondition structure
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition, rest.FieldCondition)
        assert field_condition.key == "metadata.doc_id"
        # Verify MatchAny structure
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ids

View File

@ -0,0 +1,677 @@
from __future__ import annotations
import dataclasses
import json
from collections.abc import Sequence
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_HumanInputFormEntityImpl,
_HumanInputFormRecipientEntityImpl,
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from dify_graph.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType
@pytest.fixture(autouse=True)
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
    """Autouse fixture: replace SQLAlchemy select/selectinload with inert fakes.

    The repository builds queries fluently; the fake absorbs join/where/options
    calls so no real SQLAlchemy machinery runs in these unit tests.
    """

    class _FakeSelect:
        def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

        def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

        def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

    monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
    monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
payload: dict[str, Any] = {
"form_content": "hi",
"inputs": [],
"user_actions": [{"id": "submit", "title": "Submit"}],
"rendered_content": "<p>hi</p>",
}
if include_expiration_time:
payload["expiration_time"] = naive_utc_now()
return json.dumps(payload, default=str)
@dataclasses.dataclass
class _DummyForm:
    """Plain stand-in for the HumanInputForm ORM model used by the fake session."""

    id: str
    workflow_run_id: str | None
    node_id: str
    tenant_id: str
    app_id: str
    form_definition: str  # JSON payload, see _make_form_definition_json
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    selected_action_id: str | None = None
    submitted_data: str | None = None  # JSON string once submitted
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
    """Plain stand-in for the HumanInputFormRecipient ORM model."""

    id: str
    form_id: str
    recipient_type: RecipientType
    access_token: str | None
class _FakeScalarResult:
def __init__(self, obj: Any):
self._obj = obj
def first(self) -> Any:
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self) -> list[Any]:
if self._obj is None:
return []
if isinstance(self._obj, list):
return list(self._obj)
return [self._obj]
class _FakeExecuteResult:
def __init__(self, rows: Sequence[tuple[Any, ...]]):
self._rows = list(rows)
def all(self) -> list[tuple[Any, ...]]:
return list(self._rows)
class _FakeSession:
    """In-memory stand-in for a SQLAlchemy session.

    Serves ``scalars`` results FIFO from a queue, resolves ``get`` lookups from
    the supplied form/recipient dicts, records added objects, and simulates DB
    default population on ``flush``. ``begin``/``__enter__``/``__exit__`` make
    it usable as its own transaction context manager.
    """

    def __init__(
        self,
        *,
        scalars_result: Any = None,
        scalars_results: list[Any] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
        execute_rows: Sequence[tuple[Any, ...]] = (),
    ):
        if scalars_results is not None:
            self._scalars_queue = list(scalars_results)
        else:
            self._scalars_queue = [scalars_result]
        self._forms = forms or {}
        self._recipients = recipients or {}
        self._execute_rows = list(execute_rows)
        self.added: list[Any] = []

    def scalars(self, _query: Any) -> _FakeScalarResult:
        # Serve queued results in order; None once the queue is exhausted.
        value = self._scalars_queue.pop(0) if self._scalars_queue else None
        return _FakeScalarResult(value)

    def execute(self, _stmt: Any) -> _FakeExecuteResult:
        return _FakeExecuteResult(self._execute_rows)

    def get(self, model_cls: Any, obj_id: str) -> Any:
        name = getattr(model_cls, "__name__", "")
        if name == "HumanInputForm":
            return self._forms.get(obj_id)
        if name == "HumanInputFormRecipient":
            return self._recipients.get(obj_id)
        return None

    def add(self, obj: Any) -> None:
        self.added.append(obj)

    def add_all(self, objs: Sequence[Any]) -> None:
        self.added.extend(list(objs))

    def flush(self) -> None:
        # Simulate DB default population for attributes referenced in entity wrappers.
        for idx, obj in enumerate(self.added):
            if hasattr(obj, "id") and obj.id in (None, ""):
                # Bug fix: the previous f"gen-{len(str(self.added))}" assigned
                # every pending object the same id (the repr-length of the list).
                # Use the object's position for unique, deterministic ids.
                obj.id = f"gen-{idx}"
            if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
                if obj.recipient_type == RecipientType.CONSOLE:
                    obj.access_token = "token-console"
                elif obj.recipient_type == RecipientType.BACKSTAGE:
                    obj.access_token = "token-backstage"
                else:
                    obj.access_token = "token-webapp"

    def refresh(self, _obj: Any) -> None:
        return None

    def begin(self) -> _FakeSession:
        # Transactions are a no-op: the session doubles as its own context manager.
        return self

    def __enter__(self) -> _FakeSession:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None
class _SessionFactoryStub:
def __init__(self, session: _FakeSession):
self._session = session
def create_session(self) -> _FakeSession:
return self._session
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
    """Point the repository module's session_factory at the given fake session."""
    monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
def test_recipient_entity_token_raises_when_missing() -> None:
    """Accessing .token asserts when the underlying access_token is None."""
    recipient = SimpleNamespace(id="r1", access_token=None)
    entity = _HumanInputFormRecipientEntityImpl(recipient)  # type: ignore[arg-type]
    with pytest.raises(AssertionError, match="access_token should not be None"):
        _ = entity.token
def test_recipient_entity_id_and_token_success() -> None:
    """id and token pass straight through from the wrapped model."""
    recipient = SimpleNamespace(id="r1", access_token="tok")
    entity = _HumanInputFormRecipientEntityImpl(recipient)  # type: ignore[arg-type]
    assert entity.id == "r1"
    assert entity.token == "tok"
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
    """web_app_token prefers a CONSOLE recipient, then STANDALONE_WEB_APP, else None."""
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
    webapp = _DummyRecipient(
        id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
    )
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console])  # type: ignore[arg-type]
    assert entity.web_app_token == "ctok"
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp])  # type: ignore[arg-type]
    assert entity.web_app_token == "wtok"
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.web_app_token is None
def test_form_entity_submitted_data_parsed() -> None:
    """A submitted form exposes parsed submitted_data and related fields."""
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        submitted_data='{"a": 1}',
        submitted_at=naive_utc_now(),
    )
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.submitted is True
    assert entity.submitted_data == {"a": 1}
    assert entity.rendered_content == "<p>x</p>"
    assert entity.selected_action_id is None
    assert entity.status == HumanInputFormStatus.WAITING
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
    """from_models backfills the definition's expiration_time from the model column."""
    expiration = naive_utc_now()
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=False),
        rendered_content="<p>x</p>",
        expiration_time=expiration,
        submitted_data='{"k": "v"}',
    )
    record = HumanInputFormRecord.from_models(form, None)  # type: ignore[arg-type]
    assert record.definition.expiration_time == expiration
    assert record.submitted_data == {"k": "v"}
    # submitted_at is unset, so the record does not count as submitted.
    assert record.submitted is False
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
    """Blank emails are skipped and duplicates (member or external) collapse to one."""
    created: list[SimpleNamespace] = []

    def fake_new(cls, form_id: str, delivery_id: str, payload: Any):  # type: ignore[no-untyped-def]
        recipient = SimpleNamespace(
            id=f"{payload.TYPE}-{len(created)}",
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
            access_token="tok",
        )
        created.append(recipient)
        return recipient

    monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    recipients = repo._create_email_recipients_from_resolved(  # type: ignore[attr-defined]
        form_id="f",
        delivery_id="d",
        members=[
            _WorkspaceMemberInfo(user_id="u1", email=""),
            _WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
            _WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
        ],
        external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
    )
    # One member recipient survives dedup; "a@example.com" external is shadowed
    # by the member, leaving one external recipient.
    assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
    """All-blank user id filters short-circuit to an empty result."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []
def test_query_workspace_members_by_ids_maps_rows() -> None:
    """Result rows map one-to-one onto _WorkspaceMemberInfo values."""
    session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"])
    assert rows == [
        _WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
        _WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
    ]
def test_query_all_workspace_members_maps_rows() -> None:
    """Whole-workspace member query maps rows onto _WorkspaceMemberInfo."""
    session = _FakeSession(execute_rows=[("u1", "a@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    rows = repo._query_all_workspace_members(session=session)
    assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
def test_repository_init_sets_tenant_id() -> None:
    """The constructor stores the tenant id on the instance."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._tenant_id == "tenant"
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
    """A web-app delivery method yields one delivery row plus one web-app recipient."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    result = repo._delivery_method_to_model(
        session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
    )
    assert result.delivery.id == "del-1"
    assert result.delivery.form_id == "form-1"
    assert len(result.recipients) == 1
    assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """An email delivery method delegates recipient creation to _build_email_recipients."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    called: dict[str, Any] = {}

    def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
        called.update(
            {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
        )
        return ["r"]

    monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
    method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
            ),
            subject="s",
            body="b",
        )
    )
    result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
    assert result.recipients == ["r"]
    assert called["delivery_id"] == "del-1"
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """whole_workspace=True resolves recipients from the full member list."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr(
        repo,
        "_query_all_workspace_members",
        lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
    )
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    recipients = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
    )
    assert recipients == ["ok"]
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """whole_workspace=False restricts member resolution to the listed user ids."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
        assert restrict_to_user_ids == ["u1"]
        return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]

    monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    recipients = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(
            whole_workspace=False,
            items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
        ),
    )
    assert recipients == ["ok"]
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_form returns None for a missing form and a populated entity otherwise."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo.get_form("run", "node") is None

    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient = _DummyRecipient(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
    )
    # First scalars() call serves the form, the second its recipient list.
    session = _FakeSession(scalars_results=[form, [recipient]])
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    entity = repo.get_form("run", "node")
    assert entity is not None
    assert entity.id == "f1"
    assert entity.recipients[0].id == "r1"
    assert entity.recipients[0].token == "tok"
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_form adds console and backstage recipients alongside the web-app one."""
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
    # uuidv7 is consumed once per generated id, in this deterministic order.
    ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
    session = _FakeSession()
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    form_config = HumanInputNodeData(
        title="Title",
        delivery_methods=[],
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
    )
    params = FormCreateParams(
        app_id="app",
        workflow_execution_id="run",
        node_id="node",
        form_config=form_config,
        rendered_content="<p>hello</p>",
        delivery_methods=[WebAppDeliveryMethod()],
        display_in_ui=True,
        resolved_default_values={},
        form_kind=HumanInputFormKind.RUNTIME,
        console_recipient_required=True,
        console_creator_account_id="acc-1",
        backstage_recipient_required=True,
    )
    entity = repo.create_form(params)
    assert entity.id == "form-id"
    assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
    # Console token should take precedence when console recipient is present.
    assert entity.web_app_token == "token-console"
    assert len(entity.recipients) == 3
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_token returns None when no recipient matches or its form is gone."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_token("tok") is None

    # A recipient without an attached form is treated the same as no match.
    recipient = SimpleNamespace(form=None)
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_token("tok") is None
def test_submission_repository_init_no_args() -> None:
    """The submission repository constructs without arguments."""
    repo = HumanInputFormSubmissionRepository()
    assert isinstance(repo, HumanInputFormSubmissionRepository)
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
    """Both lookup paths return a record populated from recipient and form."""
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient = SimpleNamespace(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
        form=form,
    )
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_token("tok")
    assert record is not None
    assert record.access_token == "tok"

    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
    assert record is not None
    assert record.recipient_id == "r1"
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_form_id_and_recipient_type returns None when nothing matches."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_submitted raises for unknown forms and records submission state otherwise."""
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)

    missing_session = _FakeSession(forms={})
    _patch_session_factory(monkeypatch, missing_session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repo.mark_submitted(
            form_id="missing",
            recipient_id=None,
            selected_action_id="a",
            form_data={},
            submission_user_id=None,
            submission_end_user_id=None,
        )

    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=fixed_now,
    )
    recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
    session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_submitted(
        form_id=form.id,
        recipient_id=recipient.id,
        selected_action_id="approve",
        form_data={"k": "v"},
        submission_user_id="u",
        submission_end_user_id="eu",
    )
    assert form.status == HumanInputFormStatus.SUBMITTED
    assert form.submitted_at == fixed_now
    assert record.submitted_data == {"k": "v"}
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_timeout rejects target statuses that are not timeout statuses."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
        repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED)  # type: ignore[arg-type]
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
    """Marking an already timed-out form is idempotent and returns its record."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.TIMEOUT,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
    assert record.status == HumanInputFormStatus.TIMEOUT
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """A form that was already submitted cannot be timed out."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.SUBMITTED,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form already submitted"):
        repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_timeout sets the timeout status and wipes all submission fields."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        selected_action_id="a",
        submitted_data="{}",
        submission_user_id="u",
        submission_end_user_id="eu",
        completed_by_recipient_id="r",
        status=HumanInputFormStatus.WAITING,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
    assert form.status == HumanInputFormStatus.EXPIRED
    assert form.selected_action_id is None
    assert form.submitted_data is None
    assert form.submission_user_id is None
    assert form.submission_end_user_id is None
    assert form.completed_by_recipient_id is None
    assert record.status == HumanInputFormStatus.EXPIRED
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_timeout raises FormNotFoundError when the form id is unknown."""
    empty_session = _FakeSession(forms={})
    _patch_session_factory(monkeypatch, empty_session)
    repository = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repository.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)

View File

@ -1,84 +1,291 @@
from datetime import datetime
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from models import Account, CreatorUserRole, EndUser, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
engine = create_engine("sqlite:///:memory:")
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
user = MagicMock(spec=Account)
user.id = str(uuid4())
user.current_tenant_id = str(uuid4())
repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=real_session_factory,
user=user,
app_id="app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
session_context = MagicMock()
session_context.__enter__.return_value = session
session_context.__exit__.return_value = False
repository._session_factory = MagicMock(return_value=session_context)
return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
return WorkflowExecution.new(
id_=execution_id,
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0.0",
graph={"nodes": [], "edges": []},
inputs={"query": "hello"},
started_at=started_at,
)
def test_save_uses_execution_started_at_when_record_does_not_exist():
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
session_factory = MagicMock(spec=sessionmaker)
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
session_factory.return_value.__enter__.return_value = session
return session_factory
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
@pytest.fixture
def mock_engine():
"""Mock SQLAlchemy Engine."""
return MagicMock(spec=Engine)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = MagicMock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = MagicMock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
outputs={"output1": "result1"},
status=WorkflowExecutionStatus.SUCCEEDED,
error_message="",
total_tokens=100,
total_steps=5,
exceptions_count=0,
started_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()
class TestSQLAlchemyWorkflowExecutionRepository:
    """Unit tests for SQLAlchemyWorkflowExecutionRepository with mocked sessions.

    Covers constructor validation, domain<->DB model mapping, and save()
    semantics (caching, created_at handling).
    """

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Constructor wires tenant/app/trigger context from an Account user."""
        app_id = "test_app_id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
        )
        assert repo._session_factory == mock_session_factory
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role == CreatorUserRole.ACCOUNT

    def test_init_with_engine(self, mock_engine, mock_account):
        """Passing an Engine builds an internal sessionmaker bound to it."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_engine,
            user=mock_account,
            app_id="test_app_id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert isinstance(repo._session_factory, sessionmaker)
        assert repo._session_factory.kw["bind"] == mock_engine

    def test_init_invalid_session_factory(self, mock_account):
        """Anything other than an Engine or sessionmaker is rejected."""
        with pytest.raises(ValueError, match="Invalid session_factory type"):
            SQLAlchemyWorkflowExecutionRepository(
                session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
            )

    def test_init_no_tenant_id(self, mock_session_factory):
        """A user without a tenant id cannot construct the repository."""
        user = MagicMock(spec=Account)
        user.current_tenant_id = None
        with pytest.raises(ValueError, match="User must have a tenant_id"):
            SQLAlchemyWorkflowExecutionRepository(
                session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
            )

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """EndUser users map to the END_USER creator role and their tenant_id."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
        )
        assert repo._tenant_id == mock_end_user.tenant_id
        assert repo._creator_user_role == CreatorUserRole.END_USER

    def test_to_domain_model(self, mock_session_factory, mock_account):
        """DB row fields are copied onto the domain WorkflowExecution."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
        )
        db_model = MagicMock(spec=WorkflowRun)
        db_model.id = str(uuid4())
        db_model.workflow_id = str(uuid4())
        db_model.type = "workflow"
        db_model.version = "1.0"
        db_model.inputs_dict = {"in": "val"}
        db_model.outputs_dict = {"out": "val"}
        db_model.graph_dict = {"nodes": []}
        db_model.status = "succeeded"
        db_model.error = "some error"
        db_model.total_tokens = 50
        db_model.total_steps = 3
        db_model.exceptions_count = 1
        db_model.created_at = datetime.now(UTC)
        db_model.finished_at = datetime.now(UTC)
        domain_model = repo._to_domain_model(db_model)
        assert domain_model.id_ == db_model.id
        assert domain_model.workflow_id == db_model.workflow_id
        assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
        assert domain_model.inputs == db_model.inputs_dict
        assert domain_model.error_message == "some error"

    def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
        """Domain execution maps to a WorkflowRun row with computed elapsed time."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Make elapsed time deterministic to avoid flaky tests
        sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
        sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
        db_model = repo._to_db_model(sample_workflow_execution)
        assert db_model.id == sample_workflow_execution.id_
        assert db_model.tenant_id == repo._tenant_id
        assert db_model.app_id == "test_app"
        assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
        assert db_model.status == sample_workflow_execution.status.value
        assert db_model.total_tokens == sample_workflow_execution.total_tokens
        assert db_model.elapsed_time == 10.0

    def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
        """None-valued optional fields stay None and elapsed time defaults to 0."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Test with empty/None fields
        sample_workflow_execution.graph = None
        sample_workflow_execution.inputs = None
        sample_workflow_execution.outputs = None
        sample_workflow_execution.error_message = None
        sample_workflow_execution.finished_at = None
        db_model = repo._to_db_model(sample_workflow_execution)
        assert db_model.graph is None
        assert db_model.inputs is None
        assert db_model.outputs is None
        assert db_model.error is None
        assert db_model.elapsed_time == 0

    def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
        """A repository constructed without an app id produces a row without one."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=None,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        db_model = repo._to_db_model(sample_workflow_execution)
        assert not hasattr(db_model, "app_id") or db_model.app_id is None
        assert db_model.tenant_id == repo._tenant_id

    def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
        """Missing triggered_from / creator id / creator role each raise ValueError."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
        )
        # Test triggered_from missing
        with pytest.raises(ValueError, match="triggered_from is required"):
            repo._to_db_model(sample_workflow_execution)
        repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repo._creator_user_id = None
        with pytest.raises(ValueError, match="created_by is required"):
            repo._to_db_model(sample_workflow_execution)
        repo._creator_user_id = "some_id"
        repo._creator_user_role = None
        with pytest.raises(ValueError, match="created_by_role is required"):
            repo._to_db_model(sample_workflow_execution)

    def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
        """save() merges and commits, then caches the persisted model by id."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        repo.save(sample_workflow_execution)
        session = mock_session_factory.return_value.__enter__.return_value
        session.merge.assert_called_once()
        session.commit.assert_called_once()
        # Check cache
        assert sample_workflow_execution.id_ in repo._execution_cache
        cached_model = repo._execution_cache[sample_workflow_execution.id_]
        assert cached_model.id == sample_workflow_execution.id_

    def test_save_uses_execution_started_at_when_record_does_not_exist(
        self, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """For a brand-new run, created_at is taken from the execution's started_at."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
        sample_workflow_execution.started_at = started_at
        session = mock_session_factory.return_value.__enter__.return_value
        # session.get returning None means no existing row.
        session.get.return_value = None
        repo.save(sample_workflow_execution)
        saved_model = session.merge.call_args.args[0]
        assert saved_model.created_at == started_at
        session.commit.assert_called_once()

    def test_save_preserves_existing_created_at_when_record_already_exists(
        self, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Re-saving an existing run must not overwrite its original created_at."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        execution_id = sample_workflow_execution.id_
        existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
        existing_run = WorkflowRun()
        existing_run.id = execution_id
        existing_run.tenant_id = repo._tenant_id
        existing_run.created_at = existing_created_at
        session = mock_session_factory.return_value.__enter__.return_value
        session.get.return_value = existing_run
        # Execution reports a later start; the stored created_at must win.
        sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
        repo.save(sample_workflow_execution)
        saved_model = session.merge.call_args.args[0]
        assert saved_model.created_at == existing_created_at
        session.commit.assert_called_once()

View File

@ -0,0 +1,772 @@
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump,
_filter_by_offload_type,
_find_first,
_replace_or_append_offload,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
    """Return a Mock standing in for an Account with the given ids."""
    account = Mock(spec=Account)
    account.id = user_id
    account.current_tenant_id = tenant_id
    return account
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
    """Return a Mock standing in for an EndUser with the given ids."""
    end_user = Mock(spec=EndUser)
    end_user.id = user_id
    end_user.tenant_id = tenant_id
    return end_user
def _execution(
    *,
    execution_id: str = "exec-id",
    node_execution_id: str = "node-exec-id",
    workflow_run_id: str = "run-id",
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
    inputs: Mapping[str, Any] | None = None,
    outputs: Mapping[str, Any] | None = None,
    process_data: Mapping[str, Any] | None = None,
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
) -> WorkflowNodeExecution:
    """Build a WorkflowNodeExecution with test defaults.

    Only the fields individual tests need to vary are overridable; the rest are
    fixed placeholders (LLM node, index 1, no error, 1.0s elapsed, created now).
    """
    return WorkflowNodeExecution(
        id=execution_id,
        node_execution_id=node_execution_id,
        workflow_id="workflow-id",
        workflow_execution_id=workflow_run_id,
        index=1,
        predecessor_node_id=None,
        node_id="node-id",
        node_type=NodeType.LLM,
        title="Title",
        inputs=inputs,
        outputs=outputs,
        process_data=process_data,
        status=status,
        error=None,
        elapsed_time=1.0,
        metadata=metadata,
        created_at=datetime.now(UTC),
        finished_at=None,
    )
class _SessionCtx:
def __init__(self, session: Any):
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
return None
def _session_factory(session: Any) -> sessionmaker:
    """Build a mocked sessionmaker whose call yields *session* via a context manager."""
    mocked_factory = Mock(spec=sessionmaker)
    mocked_factory.return_value = _SessionCtx(session)
    return mocked_factory
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
    """Both an Engine and a sessionmaker are accepted; EndUser maps to the end_user role."""
    # FileService is stubbed out so the constructor does no real uploads.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    engine: Engine = create_engine("sqlite:///:memory:")
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=engine,
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # An Engine is wrapped in a sessionmaker internally.
    assert isinstance(repo._session_factory, sessionmaker)
    sm = Mock(spec=sessionmaker)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=sm,
        user=_mock_end_user(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
    )
    assert repo._creator_user_role.value == "end_user"
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
    """A session_factory that is neither Engine nor sessionmaker raises ValueError."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    with pytest.raises(ValueError, match="Invalid session_factory type"):
        SQLAlchemyWorkflowNodeExecutionRepository(  # type: ignore[arg-type]
            session_factory=object(),
            user=_mock_account(),
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """A user whose tenant id is None cannot construct the repository."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    user = _mock_account()
    user.current_tenant_id = None
    with pytest.raises(ValueError, match="User must have a tenant_id"):
        SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=Mock(spec=sessionmaker),
            user=user,
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
    """_create_truncator forwards the dify_config size limits to VariableTruncator."""
    created: dict[str, Any] = {}

    # Captures the keyword arguments the repository passes to the truncator.
    class FakeTruncator:
        def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
            created.update(
                {
                    "max_size_bytes": max_size_bytes,
                    "array_element_limit": array_element_limit,
                    "string_length_limit": string_length_limit,
                }
            )

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
        FakeTruncator,
    )
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    _ = repo._create_truncator()
    assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
    """Exercise the module-level helper functions used for offload bookkeeping."""
    # Keys are serialized in sorted order for deterministic output.
    assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
    assert _find_first([], lambda _: True) is None
    assert _find_first([1, 2, 3], lambda x: x > 1) == 2
    off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2
    # Replacing the INPUTS offload drops the old entry and appends the new one.
    replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
    assert len(replaced) == 2
    assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """_to_db_model serializes deterministically and requires triggered_from."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
    # Happy path: deterministic json dump should be sorted
    db_model = repo._to_db_model(execution)
    assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
    assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
    # Clearing triggered_from after construction must make mapping fail.
    repo._triggered_from = None
    with pytest.raises(ValueError, match="triggered_from is required"):
        repo._to_db_model(execution)
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
    """_to_db_model raises when creator user id or role is missing."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution()
    db_model = repo._to_db_model(execution)
    assert db_model.app_id == "app"
    repo._creator_user_id = None
    with pytest.raises(ValueError, match="created_by is required"):
        repo._to_db_model(execution)
    repo._creator_user_id = "user"
    repo._creator_user_role = None
    with pytest.raises(ValueError, match="created_by_role is required"):
        repo._to_db_model(execution)
def test_is_duplicate_key_error_and_regenerate_id(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Only UniqueViolation-backed IntegrityErrors count as duplicates; id regeneration logs a warning."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    assert repo._is_duplicate_key_error(duplicate_error) is True
    assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
    execution = _execution(execution_id="old-id")
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "old-id"
    # Pin the new uuid so both objects can be checked deterministically.
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    caplog.set_level(logging.WARNING)
    repo._regenerate_id_on_duplicate(execution, db_model)
    assert execution.id == "new-id"
    assert db_model.id == "new-id"
    assert any("Duplicate key conflict" in r.message for r in caplog.records)
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
    """_persist_to_database copies public attrs onto an existing row, or adds a new one."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id1"
    db_model.node_execution_id = "node1"
    db_model.foo = "bar"  # type: ignore[attr-defined]
    # Underscore-prefixed attrs should not be copied onto the existing row.
    db_model.__dict__["_private"] = "x"
    existing = SimpleNamespace()
    session.get.return_value = existing
    repo._persist_to_database(db_model)
    assert existing.foo == "bar"
    session.add.assert_not_called()
    assert repo._node_execution_cache["node1"] is db_model
    # Second pass: no existing row, so the model is added and re-cached.
    session.reset_mock()
    session.get.return_value = None
    repo._node_execution_cache.clear()
    repo._persist_to_database(db_model)
    session.add.assert_called_once_with(db_model)
    assert repo._node_execution_cache["node1"] is db_model
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
    """_truncate_and_upload short-circuits on None values and untruncated mappings."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None

    # Truncator reports "not truncated", so nothing is uploaded.
    class FakeTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            return value, False

    monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
    assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
    """When truncation happens, the full value is uploaded and an offload record is built."""
    uploaded: dict[str, Any] = {}

    # Captures the upload call and returns a stub file record.
    class FakeFileService:
        def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any):  # type: ignore[no-untyped-def]
            uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
            return SimpleNamespace(id="file-id", key="file-key")

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    # Truncator always reports "truncated" to force the upload path.
    class FakeTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            return {"truncated": True}, True

    monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
    result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
    assert result is not None
    assert result.truncated_value == {"truncated": True}
    assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
    assert result.offload.file_id == "file-id"
    assert result.offload.type_ == ExecutionOffLoadType.INPUTS
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
    """Offloaded inputs/outputs/process_data are loaded from storage; DB columns stay as truncated values."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = NodeType.LLM
    db_model.title = "t"
    # Column values hold the truncated representations.
    db_model.inputs = json.dumps({"trunc": "i"})
    db_model.process_data = json.dumps({"trunc": "p"})
    db_model.outputs = json.dumps({"trunc": "o"})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = json.dumps({"total_tokens": 3})
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
    off_in.file = SimpleNamespace(key="k-in")
    off_out.file = SimpleNamespace(key="k-out")
    off_proc.file = SimpleNamespace(key="k-proc")
    # Order deliberately scrambled to show lookup is by type, not position.
    db_model.offload_data = [off_out, off_in, off_proc]

    # Storage returns a payload derived from the file key so origins are traceable.
    def fake_load(key: str) -> bytes:
        return json.dumps({"full": key}).encode()

    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
    domain = repo._to_domain_model(db_model)
    assert domain.inputs == {"full": "k-in"}
    assert domain.outputs == {"full": "k-out"}
    assert domain.process_data == {"full": "k-proc"}
    assert domain.get_truncated_inputs() == {"trunc": "i"}
    assert domain.get_truncated_outputs() == {"trunc": "o"}
    assert domain.get_truncated_process_data() == {"trunc": "p"}
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without offload records, column JSON is used directly as domain values."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = NodeType.LLM
    db_model.title = "t"
    db_model.inputs = json.dumps({"i": 1})
    db_model.process_data = json.dumps({"p": 2})
    db_model.outputs = json.dumps({"o": 3})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = "{}"
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    db_model.offload_data = []
    domain = repo._to_domain_model(db_model)
    assert domain.inputs == {"i": 1}
    assert domain.outputs == {"o": 3}
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
    """_json_encode must route values through WorkflowRuntimeTypeConverter before dumping."""

    class StubConverter:
        def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
            return {"wrapped": values["a"]}

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
        StubConverter,
    )
    encoded = SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1})
    assert encoded == '{"wrapped": 1}'
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
    """save_execution_data truncates inputs when needed and merges onto the existing row."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    # Existing row already carries one INPUTS offload record.
    session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
        id="id",
        offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
    trunc_result = SimpleNamespace(
        truncated_value={"trunc": True},
        offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
    )
    # Only the inputs mapping triggers truncation in this scenario.
    monkeypatch.setattr(
        repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
    )
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
    repo.save_execution_data(execution)
    # Inputs should be truncated, outputs/process_data encoded directly
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.inputs) == {"trunc": True}
    assert json.loads(db_model.outputs) == {"b": 2}
    assert json.loads(db_model.process_data) == {"c": 3}
    assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
    assert execution.get_truncated_inputs() == {"trunc": True}
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """Outputs and process_data can each be truncated independently of inputs."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    existing = SimpleNamespace(
        id="id",
        offload_data=[],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session = MagicMock()
    session.execute.return_value.scalars.return_value.first.return_value = existing
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})

    # Truncate outputs and process_data but leave inputs untouched.
    def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
        if values == {"b": 2}:
            return SimpleNamespace(
                truncated_value={"b": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
            )
        if values == {"c": 3}:
            return SimpleNamespace(
                truncated_value={"c": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
            )
        return None

    monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
    repo.save_execution_data(execution)
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.outputs) == {"b": "trunc"}
    assert json.loads(db_model.process_data) == {"c": "trunc"}
    assert execution.get_truncated_outputs() == {"b": "trunc"}
    assert execution.get_truncated_process_data() == {"c": "trunc"}
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
    """When no row exists yet, save_execution_data builds a fresh db model and merges it."""
    # Stub FileService so no real upload can happen.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    mock_session = MagicMock()
    mock_session.merge = Mock()
    mock_session.flush = Mock()
    # The lookup finds nothing, forcing the "create new model" path.
    mock_session.execute.return_value.scalars.return_value.first.return_value = None
    # Allow `with session.begin():` on the mock.
    mock_session.begin.return_value.__enter__ = Mock(return_value=mock_session)
    mock_session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(mock_session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1})
    stub_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
    monkeypatch.setattr(repo, "_to_db_model", lambda *_: stub_model)
    # No truncation applies; encoding is plain json.dumps.
    monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
    repo.save_execution_data(execution)
    merged = mock_session.merge.call_args.args[0]
    assert merged.inputs == '{"a": 1}'
def test_save_retries_duplicate_and_logs_non_duplicate(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """save() retries with a fresh uuidv7 id on duplicate-key errors; other integrity errors are logged and re-raised."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(execution_id="id")
    # An IntegrityError whose orig is a UniqueViolation is the retryable duplicate-key case.
    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    # An IntegrityError without a UniqueViolation cause must NOT be retried.
    other_error = IntegrityError("other", params=None, orig=None)
    calls = {"n": 0}
    def persist(_db_model: Any) -> None:
        # Fail only the first attempt so the retry can succeed.
        calls["n"] += 1
        if calls["n"] == 1:
            raise duplicate_error
    monkeypatch.setattr(repo, "_persist_to_database", persist)
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    repo.save(execution)
    # The retry replaced the id and the execution landed in the cache.
    assert execution.id == "new-id"
    assert repo._node_execution_cache[execution.node_execution_id] is not None
    caplog.set_level(logging.ERROR)
    # Second scenario: a non-duplicate integrity error is logged and propagated.
    monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
    with pytest.raises(IntegrityError):
        repo.save(_execution(execution_id="id2", node_execution_id="node2"))
    assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
def test_save_logs_and_reraises_on_unexpected_error(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """An unexpected persistence failure is logged at ERROR level and propagated to the caller."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    caplog.set_level(logging.ERROR)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    def explode(_db_model: Any) -> None:
        # Simulate a non-IntegrityError failure inside persistence.
        raise RuntimeError("boom")
    monkeypatch.setattr(repo, "_persist_to_database", explode)
    with pytest.raises(RuntimeError, match="boom"):
        repo.save(_execution(execution_id="id3", node_execution_id="node3"))
    assert any("Failed to save workflow node execution" in record.message for record in caplog.records)
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
    """Fetched db models are returned as-is and cached by node_execution_id when one is present."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    class FakeStmt:
        # Fluent stand-in for a SQLAlchemy Select that records how it was built.
        def __init__(self) -> None:
            self.where_calls = 0
            self.order_by_args: tuple[Any, ...] | None = None
        def where(self, *_args: Any) -> FakeStmt:
            self.where_calls += 1
            return self
        def order_by(self, *args: Any) -> FakeStmt:
            self.order_by_args = args
            return self
    stmt = FakeStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
    model1 = SimpleNamespace(node_execution_id="n1")
    model2 = SimpleNamespace(node_execution_id=None)  # returned but not cached (no key)
    session = MagicMock()
    session.scalars.return_value.all.return_value = [model1, model2]
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # "missing" is not a model attribute — exercises the unknown-order-field branch.
    order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
    db_models = repo.get_db_models_by_workflow_run("run", order)
    assert db_models == [model1, model2]
    assert repo._node_execution_cache["n1"] is model1
    assert stmt.order_by_args is not None
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
    """An ascending order_direction is accepted and the query runs without error."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    class ChainingStmt:
        # Fluent stand-in for a SQLAlchemy Select: every call returns self.
        def where(self, *_args: Any) -> "ChainingStmt":
            return self
        def order_by(self, *args: Any) -> "ChainingStmt":
            self.args = args  # type: ignore[attr-defined]
            return self
    fake_stmt = ChainingStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: fake_stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
    db_session = MagicMock()
    db_session.scalars.return_value.all.return_value = []
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(db_session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_workflow_run maps each db model to a domain model via the thread pool."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
    monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
    monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
    class FakeExecutor:
        # Synchronous ThreadPoolExecutor stand-in; also pins the timeout argument.
        def __enter__(self) -> FakeExecutor:
            return self
        def __exit__(self, exc_type, exc, tb) -> None:
            return None
        def map(self, func, items, timeout: int):  # type: ignore[no-untyped-def]
            assert timeout == 30
            return list(map(func, items))
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
        lambda max_workers: FakeExecutor(),
    )
    result = repo.get_by_workflow_run("run", order_config=None)
    # Mapping preserves the db order.
    assert result == ["domain:db1", "domain:db2"]

View File

@ -0,0 +1,137 @@
import json
from unittest.mock import patch
from core.schemas.registry import SchemaRegistry
class TestSchemaRegistry:
    """Unit tests for SchemaRegistry loading, lookup, and listing behavior."""
    def test_initialization(self, tmp_path):
        """A fresh registry records its base dir and starts with empty caches."""
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        registry = SchemaRegistry(str(base_dir))
        assert registry.base_dir == base_dir
        assert registry.versions == {}
        assert registry.metadata == {}
    def test_default_registry_singleton(self):
        """default_registry() returns the same instance on every call."""
        registry1 = SchemaRegistry.default_registry()
        registry2 = SchemaRegistry.default_registry()
        assert registry1 is registry2
        assert isinstance(registry1, SchemaRegistry)
    def test_load_all_versions_non_existent_dir(self, tmp_path):
        """Loading from a missing base dir is a no-op rather than an error."""
        base_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(base_dir))
        registry.load_all_versions()
        assert registry.versions == {}
    def test_load_all_versions_filtering(self, tmp_path):
        """Only version-named directories (e.g. "v1") are loaded; files and other dirs are skipped."""
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        (base_dir / "not_a_version_dir").mkdir()
        (base_dir / "v1").mkdir()
        (base_dir / "some_file.txt").write_text("content")
        registry = SchemaRegistry(str(base_dir))
        with patch.object(registry, "_load_version_dir") as mock_load:
            registry.load_all_versions()
            mock_load.assert_called_once()
            assert mock_load.call_args[0][0] == "v1"
    def test_load_version_dir_filtering(self, tmp_path):
        """Only *.json files inside a version dir are treated as schemas."""
        version_dir = tmp_path / "v1"
        version_dir.mkdir()
        (version_dir / "schema1.json").write_text("{}")
        (version_dir / "not_a_schema.txt").write_text("content")
        registry = SchemaRegistry(str(tmp_path))
        with patch.object(registry, "_load_schema") as mock_load:
            registry._load_version_dir("v1", version_dir)
            mock_load.assert_called_once()
            assert mock_load.call_args[0][1] == "schema1"
    def test_load_version_dir_non_existent(self, tmp_path):
        """A missing version dir leaves the registry untouched."""
        version_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(tmp_path))
        registry._load_version_dir("v1", version_dir)
        assert "v1" not in registry.versions
    def test_load_schema_success(self, tmp_path):
        """A valid schema file is stored and indexed under its canonical URI."""
        schema_path = tmp_path / "test.json"
        schema_content = {"title": "Test Schema", "description": "A test schema"}
        schema_path.write_text(json.dumps(schema_content))
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "test", schema_path)
        assert registry.versions["v1"]["test"] == schema_content
        uri = "https://dify.ai/schemas/v1/test.json"
        assert registry.metadata[uri]["title"] == "Test Schema"
        assert registry.metadata[uri]["version"] == "v1"
    def test_load_schema_invalid_json(self, tmp_path, caplog):
        """Malformed JSON is logged, not raised."""
        schema_path = tmp_path / "invalid.json"
        schema_path.write_text("invalid json")
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "invalid", schema_path)
        assert "Failed to load schema v1/invalid" in caplog.text
    def test_load_schema_os_error(self, tmp_path, caplog):
        """Filesystem errors while reading a schema are logged, not raised."""
        schema_path = tmp_path / "error.json"
        schema_path.write_text("{}")
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        with patch("builtins.open", side_effect=OSError("Read error")):
            registry._load_schema("v1", "error", schema_path)
        assert "Failed to load schema v1/error" in caplog.text
    def test_get_schema(self):
        """get_schema resolves canonical URIs and returns None for anything else."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"type": "object"}}}
        # Valid URI
        assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
        # Invalid URI
        assert registry.get_schema("invalid-uri") is None
        # Missing version
        assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
    def test_list_versions(self):
        """Versions are listed in sorted order."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v2": {}, "v1": {}}
        assert registry.list_versions() == ["v1", "v2"]
    def test_list_schemas(self):
        """Schema names are sorted; unknown versions yield an empty list."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"b": {}, "a": {}}}
        assert registry.list_schemas("v1") == ["a", "b"]
        assert registry.list_schemas("v2") == []
    def test_get_all_schemas_for_version(self):
        """Each entry carries name, label (title with name fallback), and the schema body."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"title": "Test Label"}}}
        results = registry.get_all_schemas_for_version("v1")
        assert len(results) == 1
        assert results[0]["name"] == "test"
        assert results[0]["label"] == "Test Label"
        assert results[0]["schema"] == {"title": "Test Label"}
        # Default label if title missing
        registry.versions["v1"]["no_title"] = {}
        results = registry.get_all_schemas_for_version("v1")
        item = next(r for r in results if r["name"] == "no_title")
        assert item["label"] == "no_title"
        # Empty if version missing
        assert registry.get_all_schemas_for_version("v2") == []

View File

@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch
from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager
def test_init_with_provided_registry():
    """A registry passed to the constructor is used directly."""
    registry_stub = MagicMock(spec=SchemaRegistry)
    assert SchemaManager(registry=registry_stub).registry == registry_stub
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
    """Without an explicit registry, the manager falls back to the default singleton."""
    fallback = MagicMock(spec=SchemaRegistry)
    mock_default_registry.return_value = fallback
    manager = SchemaManager()
    mock_default_registry.assert_called_once()
    assert manager.registry == fallback
def test_get_all_schema_definitions():
    """The manager forwards the version and returns the registry's definitions untouched."""
    registry_stub = MagicMock(spec=SchemaRegistry)
    definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
    registry_stub.get_all_schemas_for_version.return_value = definitions
    assert SchemaManager(registry=registry_stub).get_all_schema_definitions(version="v2") == definitions
    registry_stub.get_all_schemas_for_version.assert_called_once_with("v2")
def test_get_schema_by_name_success():
    """A found schema is wrapped in a {name, schema} dict, looked up by canonical URI."""
    registry_stub = MagicMock(spec=SchemaRegistry)
    schema_body = {"type": "object"}
    registry_stub.get_schema.return_value = schema_body
    manager = SchemaManager(registry=registry_stub)
    wrapped = manager.get_schema_by_name("my_schema", version="v1")
    registry_stub.get_schema.assert_called_once_with("https://dify.ai/schemas/v1/my_schema.json")
    assert wrapped == {"name": "my_schema", "schema": schema_body}
def test_get_schema_by_name_not_found():
    """A missing schema yields None rather than an error."""
    registry_stub = MagicMock(spec=SchemaRegistry)
    registry_stub.get_schema.return_value = None
    manager = SchemaManager(registry=registry_stub)
    assert manager.get_schema_by_name("non_existent", version="v1") is None
def test_list_available_schemas():
    """Schema listing is delegated to the registry for the requested version."""
    registry_stub = MagicMock(spec=SchemaRegistry)
    registry_stub.list_schemas.return_value = ["schema1", "schema2"]
    listed = SchemaManager(registry=registry_stub).list_available_schemas(version="v1")
    registry_stub.list_schemas.assert_called_once_with("v1")
    assert listed == ["schema1", "schema2"]
def test_list_available_versions():
    """Version listing is delegated straight to the registry."""
    registry_stub = MagicMock(spec=SchemaRegistry)
    registry_stub.list_versions.return_value = ["v1", "v2"]
    versions = SchemaManager(registry=registry_stub).list_available_versions()
    registry_stub.list_versions.assert_called_once()
    assert versions == ["v1", "v2"]

View File

@ -16,6 +16,7 @@ from uuid import uuid4
import pytest
from models.enums import ConversationFromSource
from models.model import (
App,
AppAnnotationHitHistory,
@ -324,7 +325,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=from_end_user_id,
)
@ -345,7 +346,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
)
conversation._inputs = inputs
@ -364,7 +365,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
)
inputs = {"query": "Hello", "context": "test"}
@ -383,7 +384,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
summary="Test summary",
)
@ -402,7 +403,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
summary=None,
)
@ -425,7 +426,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
override_model_configs='{"model": "gpt-4"}',
)
@ -446,7 +447,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=from_end_user_id,
dialogue_count=5,
)
@ -487,7 +488,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
# Assert
@ -511,7 +512,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
message._inputs = inputs
@ -533,7 +534,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
inputs = {"query": "Hello", "context": "test"}
@ -555,7 +556,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
override_model_configs='{"model": "gpt-4"}',
)
@ -578,7 +579,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
message_metadata=json.dumps(metadata),
)
@ -600,7 +601,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
message_metadata=None,
)
@ -627,7 +628,7 @@ class TestMessageModel:
answer_unit_price=Decimal("0.0002"),
total_price=Decimal("0.0003"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
status="normal",
)
message.id = str(uuid4())
@ -988,7 +989,7 @@ class TestModelIntegration:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
)
conversation.id = conversation_id
@ -1003,7 +1004,7 @@ class TestModelIntegration:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
message.id = message_id
@ -1064,7 +1065,7 @@ class TestModelIntegration:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
message.id = message_id
@ -1158,7 +1159,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = str(uuid4())
@ -1183,7 +1184,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1215,7 +1216,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1307,7 +1308,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1361,7 +1362,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1418,7 +1419,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id

View File

@ -1,135 +0,0 @@
"""Unit tests for non-SQL helper logic in workflow run repository."""
import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason
from repositories.sqlalchemy_api_workflow_run_repository import (
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
)
@pytest.fixture
def sample_workflow_pause() -> Mock:
    """Create a sample WorkflowPause model."""
    pause = Mock(spec=WorkflowPauseModel)
    pause.id = "pause-123"
    pause.workflow_id = "workflow-123"
    pause.workflow_run_id = "workflow-run-123"
    pause.state_object_key = "workflow-state-123.json"
    # Not yet resumed — the pause is still active.
    pause.resumed_at = None
    pause.created_at = datetime.now(UTC)
    return pause
class TestPrivateWorkflowPauseEntity:
    """Test _PrivateWorkflowPauseEntity class."""
    def test_properties(self, sample_workflow_pause: Mock) -> None:
        """Test entity properties."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        # Assert: properties simply delegate to the underlying model.
        assert entity.id == sample_workflow_pause.id
        assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
        assert entity.resumed_at == sample_workflow_pause.resumed_at
    def test_get_state(self, sample_workflow_pause: Mock) -> None:
        """Test getting state from storage."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        expected_state = b'{"test": "state"}'
        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state
            # Act
            result = entity.get_state()
            # Assert: the state blob is fetched from storage by its object key.
            assert result == expected_state
            mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
    def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
        """Test state caching in get_state method."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        expected_state = b'{"test": "state"}'
        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state
            # Act
            result1 = entity.get_state()
            result2 = entity.get_state()
            # Assert: the second call is served from the cache — storage hit only once.
            assert result1 == expected_state
            assert result2 == expected_state
            mock_storage.load.assert_called_once()
class TestBuildHumanInputRequiredReason:
    """Test helper that builds HumanInputRequired pause reasons."""
    def test_prefers_backstage_token_when_available(self) -> None:
        """Use backstage token when multiple recipient types may exist."""
        # Arrange: a full form definition with one input and one action.
        expiration_time = datetime.now(UTC)
        form_definition = FormDefinition(
            form_content="content",
            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
            user_actions=[UserAction(id="approve", title="Approve")],
            rendered_content="rendered",
            expiration_time=expiration_time,
            default_values={"name": "Alice"},
            node_title="Ask Name",
            display_in_ui=True,
        )
        form_model = HumanInputForm(
            id="form-1",
            tenant_id="tenant-1",
            app_id="app-1",
            workflow_run_id="run-1",
            node_id="node-1",
            form_definition=form_definition.model_dump_json(),
            rendered_content="rendered",
            status=HumanInputFormStatus.WAITING,
            expiration_time=expiration_time,
        )
        reason_model = WorkflowPauseReason(
            pause_id="pause-1",
            type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
            form_id="form-1",
            node_id="node-1",
            message="",
        )
        # The backstage recipient carries the access token that should be chosen.
        access_token = secrets.token_urlsafe(8)
        backstage_recipient = HumanInputFormRecipient(
            form_id="form-1",
            delivery_id="delivery-1",
            recipient_type=RecipientType.BACKSTAGE,
            recipient_payload=BackstageRecipientPayload().model_dump_json(),
            access_token=access_token,
        )
        # Act
        reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
        # Assert: the reason mirrors the form definition and uses the backstage token.
        assert isinstance(reason, HumanInputRequired)
        assert reason.form_token == access_token
        assert reason.node_title == "Ask Name"
        assert reason.form_content == "content"
        assert reason.inputs[0].output_variable_name == "name"
        assert reason.actions[0].id == "approve"

View File

@ -1,251 +0,0 @@
"""Unit tests for workflow run repository with status filter."""
import uuid
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import sessionmaker
from models import WorkflowRun, WorkflowRunTriggeredFrom
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
class TestDifyAPISQLAlchemyWorkflowRunRepository:
    """Test workflow run repository with status filtering."""
    @pytest.fixture
    def mock_session_maker(self):
        """Create a mock session maker."""
        return MagicMock(spec=sessionmaker)
    @pytest.fixture
    def repository(self, mock_session_maker):
        """Create repository instance with mock session."""
        return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
    def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker):
        """Test getting paginated workflow runs without status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        # Session maker is used as a context manager inside the repository.
        mock_session_maker.return_value.__enter__.return_value = mock_session
        mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)]
        mock_session.scalars.return_value.all.return_value = mock_runs
        # Act
        result = repository.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=20,
            last_id=None,
            status=None,
        )
        # Assert: fewer rows than the limit means no further page.
        assert len(result.data) == 3
        assert result.limit == 20
        assert result.has_more is False
    def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker):
        """Test getting paginated workflow runs with status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session
        mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)]
        mock_session.scalars.return_value.all.return_value = mock_runs
        # Act
        result = repository.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=20,
            last_id=None,
            status="succeeded",
        )
        # Assert
        assert len(result.data) == 2
        assert all(run.status == "succeeded" for run in result.data)
    def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker):
        """Test getting workflow runs count without status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session
        # Mock the GROUP BY query results
        mock_results = [
            ("succeeded", 5),
            ("failed", 2),
            ("running", 1),
        ]
        mock_session.execute.return_value.all.return_value = mock_results
        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status=None,
        )
        # Assert: total is the sum of all buckets; absent statuses default to zero.
        assert result["total"] == 8
        assert result["succeeded"] == 5
        assert result["failed"] == 2
        assert result["running"] == 1
        assert result["stopped"] == 0
        assert result["partial-succeeded"] == 0
    def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker):
        """Test getting workflow runs count with status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session
        # Mock the count query for succeeded status
        mock_session.scalar.return_value = 5
        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="succeeded",
        )
        # Assert: with a status filter only that bucket (and total) is populated.
        assert result["total"] == 5
        assert result["succeeded"] == 5
        assert result["running"] == 0
        assert result["failed"] == 0
        assert result["stopped"] == 0
        assert result["partial-succeeded"] == 0
    def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker):
        """Test that invalid status is still counted in total but not in any specific status."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session
        # Mock count query returning 0 for invalid status
        mock_session.scalar.return_value = 0
        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="invalid_status",
        )
        # Assert
        assert result["total"] == 0
        assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"])
    def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker):
        """Test getting workflow runs count with time range filter verifies SQL query construction."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session
        # Mock the GROUP BY query results
        mock_results = [
            ("succeeded", 3),
            ("running", 2),
        ]
        mock_session.execute.return_value.all.return_value = mock_results
        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status=None,
            time_range="1d",
        )
        # Assert results
        assert result["total"] == 5
        assert result["succeeded"] == 3
        assert result["running"] == 2
        assert result["failed"] == 0
        # Verify that execute was called (which means GROUP BY query was used)
        assert mock_session.execute.called, "execute should have been called for GROUP BY query"
        # Verify SQL query includes time filter by checking the statement
        call_args = mock_session.execute.call_args
        assert call_args is not None, "execute should have been called with a statement"
        # The first argument should be the SQL statement
        stmt = call_args[0][0]
        # Convert to string to inspect the query
        query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
        # Verify the query includes created_at filter
        # The query should have a WHERE clause with created_at comparison
        assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
            "Query should include created_at filter for time range"
        )
    def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker):
        """Test getting workflow runs count with both status and time range filters verifies SQL query."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session
        # Mock the count query for running status within time range
        mock_session.scalar.return_value = 2
        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="running",
            time_range="1d",
        )
        # Assert results
        assert result["total"] == 2
        assert result["running"] == 2
        assert result["succeeded"] == 0
        assert result["failed"] == 0
        # Verify that scalar was called (which means COUNT query was used)
        assert mock_session.scalar.called, "scalar should have been called for count query"
        # Verify SQL query includes both status and time filter
        call_args = mock_session.scalar.call_args
        assert call_args is not None, "scalar should have been called with a statement"
        # The first argument should be the SQL statement
        stmt = call_args[0][0]
        # Convert to string to inspect the query
        query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
        # Verify the query includes both filters
        assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
            "Query should include created_at filter for time range"
        )
        assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), (
            "Query should include status filter"
        )

View File

@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations:
# Assert
assert result == {}
def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request):
    """Bulk plan retrieval coerces a string expiration_date into an int."""
    # Arrange: the billing API returns the timestamp as a string
    mock_send_request.return_value = {
        "data": {
            "tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"},
        }
    }
    # Act
    plans = BillingService.get_plan_bulk(["tenant-1"])
    # Assert: the tenant is present and the timestamp was converted to int
    assert "tenant-1" in plans
    entry = plans["tenant-1"]
    assert isinstance(entry["expiration_date"], int)
    assert entry["expiration_date"] == 1735689600
def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
"""Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant)."""
# Arrange

View File

@ -15,6 +15,7 @@ from sqlalchemy import asc, desc
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.enums import ConversationFromSource
from models.model import App, Conversation, EndUser, Message
from services.conversation_service import ConversationService
from services.errors.conversation import (
@ -350,7 +351,7 @@ class TestConversationServiceGetConversation:
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_account_mock()
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
from_account_id=user.id, from_source="console"
from_account_id=user.id, from_source=ConversationFromSource.CONSOLE
)
mock_query = mock_db_session.query.return_value
@ -374,7 +375,7 @@ class TestConversationServiceGetConversation:
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_end_user_mock()
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
from_end_user_id=user.id, from_source="api"
from_end_user_id=user.id, from_source=ConversationFromSource.API
)
mock_query = mock_db_session.query.return_value
@ -1111,7 +1112,7 @@ class TestConversationServiceEdgeCases:
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
from_source="api", from_end_user_id="user-123"
from_source=ConversationFromSource.API, from_end_user_id="user-123"
)
mock_session.scalars.return_value.all.return_value = [conversation]
@ -1143,7 +1144,7 @@ class TestConversationServiceEdgeCases:
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
conversation = ConversationServiceTestDataFactory.create_conversation_mock(
from_source="console", from_account_id="account-123"
from_source=ConversationFromSource.CONSOLE, from_account_id="account-123"
)
mock_session.scalars.return_value.all.return_value = [conversation]

View File

@ -0,0 +1,558 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataArgs,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@dataclass
class _DocumentStub:
    """Lightweight stand-in for a Document row; field order defines the init signature."""

    id: str
    name: str
    uploader: str
    upload_date: datetime
    last_update_date: datetime
    data_source_type: str
    doc_metadata: dict[str, object] | None
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
    """Patch the db object used by MetadataService and attach a mock session."""
    patched = mocker.patch("services.metadata_service.db")
    patched.session = MagicMock()
    return patched
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
    """Patch the redis client that MetadataService uses for lock bookkeeping."""
    patched = mocker.patch("services.metadata_service.redis_client")
    return patched
@pytest.fixture
def mock_current_account(mocker: MockerFixture) -> MagicMock:
    """Patch current_account_with_tenant to return a fixed (user, tenant) pair."""
    fake_user = SimpleNamespace(id="user-1")
    return mocker.patch(
        "services.metadata_service.current_account_with_tenant",
        return_value=(fake_user, "tenant-1"),
    )
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
    """Create a _DocumentStub with a fixed timestamp and uploader for deterministic tests."""
    moment = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
    return _DocumentStub(
        id=document_id,
        name=f"doc-{document_id}",
        uploader="qa@example.com",
        upload_date=moment,
        last_update_date=moment,
        data_source_type="upload_file",
        doc_metadata=doc_metadata,
    )
def _dataset(**kwargs: Any) -> Dataset:
    """Build a Dataset-shaped namespace; cast keeps static type checkers satisfied."""
    stub = SimpleNamespace(**kwargs)
    return cast(Dataset, stub)
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
    """Metadata names longer than 255 characters are rejected before any DB work."""
    # Arrange
    overlong_args = MetadataArgs(type="string", name="x" * 256)
    # Act + Assert
    with pytest.raises(ValueError, match="cannot exceed 255"):
        MetadataService.create_metadata("dataset-1", overlong_args)
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
    mock_db: MagicMock,
    mock_current_account: MagicMock,
) -> None:
    """Creating a metadata field whose name already exists in the dataset must fail."""
    # Arrange: the duplicate-name lookup finds an existing row
    metadata_args = MetadataArgs(type="string", name="priority")
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
    # Act + Assert
    with pytest.raises(ValueError, match="already exists"):
        MetadataService.create_metadata("dataset-1", metadata_args)
    # Assert: the tenant context was resolved before failing
    mock_current_account.assert_called_once()
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
    mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
    """Creating a field named after a built-in field must be rejected."""
    # Arrange: no user-defined duplicate exists, so only the built-in check can fire
    metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Built-in fields"):
        MetadataService.create_metadata("dataset-1", metadata_args)
def test_create_metadata_should_persist_metadata_when_input_is_valid(
    mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
    """A valid, non-duplicate name is persisted with tenant and user attribution."""
    # Arrange: the duplicate-name lookup finds nothing
    args = MetadataArgs(type="number", name="score")
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act
    created = MetadataService.create_metadata("dataset-1", args)
    # Assert: entity carries the expected attribution and was persisted exactly once
    assert created.tenant_id == "tenant-1"
    assert created.dataset_id == "dataset-1"
    assert created.type == "number"
    assert created.name == "score"
    assert created.created_by == "user-1"
    mock_db.session.add.assert_called_once_with(created)
    mock_db.session.commit.assert_called_once()
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
    """Renaming to a name longer than 255 characters is rejected up front."""
    # Arrange
    oversized = "x" * 256
    # Act + Assert
    with pytest.raises(ValueError, match="cannot exceed 255"):
        MetadataService.update_metadata_name("dataset-1", "metadata-1", oversized)
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
    mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
    """Renaming to a name already used by another field in the dataset must fail."""
    # Arrange: the duplicate-name lookup finds an existing row
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
    # Act + Assert
    with pytest.raises(ValueError, match="already exists"):
        MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
    # Assert: the tenant context was resolved before failing
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
    mock_db: MagicMock,
    mock_current_account: MagicMock,
) -> None:
    """Built-in field names are reserved and may not be used as a new name."""
    # Arrange: no user-defined duplicate exists, so only the built-in check can fire
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Built-in fields"):
        MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
    # Assert
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Renaming a metadata field migrates the key in every bound document's doc_metadata."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
    mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
    metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
    bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
    # NOTE: side_effect order mirrors the service's query sequence:
    # duplicate-name check, metadata lookup, then binding lookup.
    query_duplicate = MagicMock()
    query_duplicate.filter_by.return_value.first.return_value = None
    query_metadata = MagicMock()
    query_metadata.filter_by.return_value.first.return_value = metadata
    query_bindings = MagicMock()
    query_bindings.filter_by.return_value.all.return_value = bindings
    mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
    doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})  # key must be renamed
    doc_2 = _build_document("2", None)  # no metadata yet: renamed key appears with None
    mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
    mock_get_documents.return_value = [doc_1, doc_2]
    # Act
    result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
    # Assert: metadata row updated and both documents migrated
    assert result is metadata
    assert metadata.name == "new_name"
    assert metadata.updated_by == "user-1"
    assert metadata.updated_at == fixed_now
    assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
    assert doc_2.doc_metadata == {"new_name": None}
    mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A rename against a missing metadata id logs the failure, frees the lock, and returns None."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock is free
    mock_logger = mocker.patch("services.metadata_service.logger")
    # side_effect order mirrors the service's query sequence: duplicate check, then metadata lookup
    query_duplicate = MagicMock()
    query_duplicate.filter_by.return_value.first.return_value = None
    query_metadata = MagicMock()
    query_metadata.filter_by.return_value.first.return_value = None
    mock_db.session.query.side_effect = [query_duplicate, query_metadata]
    # Act
    result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
    # Assert: failure is logged and the lock is always released
    assert result is None
    mock_logger.exception.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
    mock_current_account.assert_called_once()
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Deleting a metadata field strips its key from bound documents and deletes the row."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock is free
    metadata = SimpleNamespace(id="metadata-1", name="obsolete")
    bindings = [SimpleNamespace(document_id="doc-1")]
    # side_effect order mirrors the service's query sequence: metadata lookup, then bindings
    query_metadata = MagicMock()
    query_metadata.filter_by.return_value.first.return_value = metadata
    query_bindings = MagicMock()
    query_bindings.filter_by.return_value.all.return_value = bindings
    mock_db.session.query.side_effect = [query_metadata, query_bindings]
    document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
    mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
    # Act
    result = MetadataService.delete_metadata("dataset-1", "metadata-1")
    # Assert: only the deleted key is removed from the document
    assert result is metadata
    assert document.doc_metadata == {"remaining": "value"}
    mock_db.session.delete.assert_called_once_with(metadata)
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_delete_metadata_should_return_none_when_metadata_is_missing(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Deleting an unknown metadata id logs the failure, frees the lock, and returns None."""
    # Arrange: the metadata lookup finds nothing
    mock_redis_client.get.return_value = None
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    mock_logger = mocker.patch("services.metadata_service.logger")
    # Act
    outcome = MetadataService.delete_metadata("dataset-1", "missing-id")
    # Assert: failure is logged and the lock is always released
    assert outcome is None
    mock_logger.exception.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
    """The built-in field catalogue exposes exactly the five known fields and types."""
    # Act
    fields = MetadataService.get_built_in_fields()
    # Assert: names match the built-in set, types follow the same order
    observed_names = {entry["name"] for entry in fields}
    assert observed_names == {
        BuiltInField.document_name,
        BuiltInField.uploader,
        BuiltInField.upload_date,
        BuiltInField.last_update_date,
        BuiltInField.source,
    }
    assert [entry["type"] for entry in fields] == ["string", "string", "time", "time", "string"]
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Enabling an already-enabled dataset is a no-op: no document scan, no commit."""
    # Arrange
    dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
    get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
    # Act
    MetadataService.enable_built_in_field(dataset)
    # Assert
    get_docs.assert_not_called()
    mock_db.session.commit.assert_not_called()
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Enabling built-in fields stamps every working document and flips the dataset flag."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock is free
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    doc_1 = _build_document("1", {"custom": "value"})  # existing metadata must be preserved
    doc_2 = _build_document("2", None)  # documents without metadata get built-ins too
    mocker.patch(
        "services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
        return_value=[doc_1, doc_2],
    )
    # Act
    MetadataService.enable_built_in_field(dataset)
    # Assert: built-in keys derived from the stub's fields appear on both documents
    assert dataset.built_in_field_enabled is True
    assert doc_1.doc_metadata is not None
    assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
    assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
    assert doc_2.doc_metadata is not None
    assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Disabling an already-disabled dataset is a no-op: no document scan, no commit."""
    # Arrange
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
    # Act
    MetadataService.disable_built_in_field(dataset)
    # Assert
    get_docs.assert_not_called()
    mock_db.session.commit.assert_not_called()
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Disabling built-in fields strips every built-in key but keeps custom metadata."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock is free
    dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
    # Document carrying every built-in key plus one user-defined key
    document = _build_document(
        "1",
        {
            BuiltInField.document_name: "doc",
            BuiltInField.uploader: "user",
            BuiltInField.upload_date: 1.0,
            BuiltInField.last_update_date: 2.0,
            BuiltInField.source: MetadataDataSource.upload_file,
            "custom": "keep",
        },
    )
    mocker.patch(
        "services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
        return_value=[document],
    )
    # Act
    MetadataService.disable_built_in_field(dataset)
    # Assert: only the user-defined key survives
    assert dataset.built_in_field_enabled is False
    assert document.doc_metadata == {"custom": "keep"}
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A non-partial update replaces doc_metadata wholesale and rebuilds bindings."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock is free
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    document = _build_document("1", {"legacy": "value"})  # pre-existing metadata should be discarded
    mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
    delete_chain = mock_db.session.query.return_value.filter_by.return_value
    delete_chain.delete.return_value = 1
    operation = DocumentMetadataOperation(
        document_id="1",
        metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
        partial_update=False,
    )
    metadata_args = MetadataOperationData(operation_data=[operation])
    # Act
    MetadataService.update_documents_metadata(dataset, metadata_args)
    # Assert: metadata replaced, old bindings deleted, lock released per document
    assert document.doc_metadata == {"priority": "high"}
    delete_chain.delete.assert_called_once()
    assert mock_db.session.commit.call_count == 1
    mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
    mock_current_account.assert_called_once()
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A partial update merges new keys, keeps existing ones, and re-stamps built-in fields."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock is free
    dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
    document = _build_document("1", {"existing": "value"})
    mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
    # Binding lookup finds an existing row for the incoming metadata detail
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
    operation = DocumentMetadataOperation(
        document_id="1",
        metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
        partial_update=True,
    )
    metadata_args = MetadataOperationData(operation_data=[operation])
    # Act
    MetadataService.update_documents_metadata(dataset, metadata_args)
    # Assert: old key preserved, new key merged, built-in source re-stamped
    assert document.doc_metadata is not None
    assert document.doc_metadata["existing"] == "value"
    assert document.doc_metadata["new_key"] == "new_value"
    assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
    assert mock_db.session.commit.call_count == 1
    assert mock_db.session.add.call_count == 1
    mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
    mock_current_account.assert_called_once()
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A missing document aborts the batch with a rollback plus lock release."""
    # Arrange: document lookup returns nothing
    mock_redis_client.get.return_value = None
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
    missing_op = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
    payload = MetadataOperationData(operation_data=[missing_op])
    # Act + Assert
    with pytest.raises(ValueError, match="Document not found"):
        MetadataService.update_documents_metadata(dataset, payload)
    # Assert: transaction rolled back and the per-document lock released
    mock_db.session.rollback.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
@pytest.mark.parametrize(
    ("dataset_id", "document_id", "expected_key"),
    [
        ("dataset-1", None, "dataset_metadata_lock_dataset-1"),
        (None, "doc-1", "document_metadata_lock_doc-1"),
    ],
)
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
    dataset_id: str | None,
    document_id: str | None,
    expected_key: str,
    mock_redis_client: MagicMock,
) -> None:
    """When no lock exists, the check acquires one with a one-hour expiry."""
    # Arrange: redis reports no existing lock
    mock_redis_client.get.return_value = None
    # Act
    MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
    # Assert: the lock key depends on whether a dataset or a document is targeted
    mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
    mock_redis_client: MagicMock,
) -> None:
    """A held dataset-level lock rejects a concurrent metadata operation."""
    # Arrange: redis reports an existing lock value
    mock_redis_client.get.return_value = 1
    # Act + Assert
    with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
        MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
    mock_redis_client: MagicMock,
) -> None:
    """A held document-level lock rejects a concurrent metadata operation."""
    # Arrange: redis reports an existing lock value
    mock_redis_client.get.return_value = 1
    # Act + Assert
    with pytest.raises(ValueError, match="document metadata operation is running"):
        MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
    """Listing metadata skips the 'built-in' id and annotates each entry with its binding count."""
    # Arrange
    dataset = _dataset(
        id="dataset-1",
        built_in_field_enabled=True,
        doc_metadata=[
            {"id": "meta-1", "name": "priority", "type": "string"},
            {"id": "built-in", "name": "ignored", "type": "string"},
            {"id": "meta-2", "name": "score", "type": "number"},
        ],
    )
    count_chain = mock_db.session.query.return_value.filter_by.return_value
    count_chain.count.side_effect = [3, 1]  # one count per non-built-in entry, in order
    # Act
    result = MetadataService.get_dataset_metadatas(dataset)
    # Assert: the 'built-in' entry is dropped; counts are attached in order
    assert result["built_in_field_enabled"] is True
    assert result["doc_metadata"] == [
        {"id": "meta-1", "name": "priority", "type": "string", "count": 3},
        {"id": "meta-2", "name": "score", "type": "number", "count": 1},
    ]
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
    """A dataset without metadata yields an empty listing and never touches the DB."""
    # Arrange
    bare_dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
    # Act
    listing = MetadataService.get_dataset_metadatas(bare_dataset)
    # Assert
    assert listing == {"doc_metadata": [], "built_in_field_enabled": False}
    mock_db.session.query.assert_not_called()

View File

@ -0,0 +1,808 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from constants import HIDDEN_VALUE
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FieldModelSchema,
FormType,
ModelCredentialSchema,
ProviderCredentialSchema,
)
from models.provider import LoadBalancingModelConfig
from services.model_load_balancing_service import ModelLoadBalancingService
def _build_provider_credential_schema() -> ProviderCredentialSchema:
    """Provider-level credential schema containing a single secret api_key field."""
    secret_field = CredentialFormSchema(
        variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT
    )
    return ProviderCredentialSchema(credential_form_schemas=[secret_field])
def _build_model_credential_schema() -> ModelCredentialSchema:
    """Model-level credential schema containing a single secret api_key field."""
    secret_field = CredentialFormSchema(
        variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT
    )
    return ModelCredentialSchema(
        model=FieldModelSchema(label=I18nObject(en_US="Model")),
        credential_form_schemas=[secret_field],
    )
def _build_provider_configuration(
*,
custom_provider: bool = False,
load_balancing_enabled: bool | None = None,
model_schema: ModelCredentialSchema | None = None,
provider_schema: ProviderCredentialSchema | None = None,
) -> MagicMock:
provider_configuration = MagicMock()
provider_configuration.provider = SimpleNamespace(
provider="openai",
model_credential_schema=model_schema,
provider_credential_schema=provider_schema,
)
provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider)
provider_configuration.extract_secret_variables.return_value = ["api_key"]
provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials
provider_configuration.get_provider_model_setting.return_value = (
None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled)
)
return provider_configuration
def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
    """Shape an arbitrary namespace like a LoadBalancingModelConfig row."""
    stub = SimpleNamespace(**kwargs)
    return cast(LoadBalancingModelConfig, stub)
@pytest.fixture
def service(mocker: MockerFixture) -> ModelLoadBalancingService:
    """A ModelLoadBalancingService wired to a mocked ProviderManager."""
    # Arrange
    manager = MagicMock()
    mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=manager)
    instance = ModelLoadBalancingService()
    instance.provider_manager = manager
    return instance
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
    """Patch the db object used by ModelLoadBalancingService and attach a mock session."""
    # Arrange
    patched = mocker.patch("services.model_load_balancing_service.db")
    patched.session = MagicMock()
    return patched
@pytest.mark.parametrize(
    ("method_name", "expected_provider_method"),
    [
        ("enable_model_load_balancing", "enable_model_load_balancing"),
        ("disable_model_load_balancing", "disable_model_load_balancing"),
    ],
)
def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists(
    method_name: str,
    expected_provider_method: str,
    service: ModelLoadBalancingService,
) -> None:
    """Both toggles delegate to the identically named provider-configuration method."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    # Act: dispatch dynamically so one test covers both service methods
    getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
    # Assert: the string model-type argument is forwarded as a ModelType enum
    getattr(provider_configuration, expected_provider_method).assert_called_once_with(
        model="gpt-4o-mini", model_type=ModelType.LLM
    )
@pytest.mark.parametrize(
    "method_name",
    ["enable_model_load_balancing", "disable_model_load_balancing"],
)
def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing(
    method_name: str,
    service: ModelLoadBalancingService,
) -> None:
    """Toggling load balancing for a provider the tenant does not have raises."""
    # Arrange: the tenant has no configured providers at all
    service.provider_manager.get_configurations.return_value = {}
    # Act + Assert
    with pytest.raises(ValueError, match="Provider openai does not exist"):
        getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing(
    service: ModelLoadBalancingService,
) -> None:
    """Listing configs for an unknown provider raises immediately."""
    # Arrange: the tenant has no configured providers at all
    service.provider_manager.get_configurations.return_value = {}
    # Act + Assert
    with pytest.raises(ValueError, match="Provider openai does not exist"):
        service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value)
def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """For a custom provider without an '__inherit__' row, one is created and listed first."""
    # Arrange
    provider_configuration = _build_provider_configuration(
        custom_provider=True,
        load_balancing_enabled=True,
        provider_schema=_build_provider_credential_schema(),
    )
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    config = SimpleNamespace(
        id="cfg-1",
        name="primary",
        encrypted_config=json.dumps({"api_key": "encrypted-key"}),
        credential_id="cred-1",
        enabled=True,
    )
    mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
    # Stub decryption to a fixed plaintext and report no cooldown
    mocker.patch(
        "services.model_load_balancing_service.encrypter.get_decrypt_decoding",
        return_value=("rsa", "cipher"),
    )
    mocker.patch(
        "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
        return_value="plain-key",
    )
    mocker.patch(
        "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
        return_value=(False, 0),
    )
    # Act
    is_enabled, configs = service.get_load_balancing_configs(
        "tenant-1",
        "openai",
        "gpt-4o-mini",
        ModelType.LLM.value,
    )
    # Assert: the synthesized inherit entry precedes the stored config, whose
    # credentials were decrypted; the new inherit row was persisted.
    assert is_enabled is True
    assert len(configs) == 2
    assert configs[0]["name"] == "__inherit__"
    assert configs[1]["name"] == "primary"
    assert configs[1]["credentials"] == {"api_key": "plain-key"}
    assert mock_db.session.add.call_count == 1
    assert mock_db.session.commit.call_count == 1
def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """An existing '__inherit__' row is moved first; bad JSON and failed decryption degrade gracefully."""
    # Arrange
    provider_configuration = _build_provider_configuration(
        custom_provider=True,
        load_balancing_enabled=None,
        provider_schema=_build_provider_credential_schema(),
    )
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    normal_config = SimpleNamespace(
        id="cfg-1",
        name="normal",
        encrypted_config=json.dumps({"api_key": "bad-encrypted"}),
        credential_id="cred-1",
        enabled=True,
    )
    inherit_config = SimpleNamespace(
        id="cfg-2",
        name="__inherit__",
        encrypted_config="not-json",  # exercises the JSON-parse fallback
        credential_id=None,
        enabled=False,
    )
    mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
        normal_config,
        inherit_config,
    ]
    mocker.patch(
        "services.model_load_balancing_service.encrypter.get_decrypt_decoding",
        return_value=("rsa", "cipher"),
    )
    mocker.patch(
        "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding",
        side_effect=ValueError("cannot decrypt"),  # exercises the decryption fallback
    )
    mocker.patch(
        "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl",
        return_value=(True, 15),
    )
    # Act
    is_enabled, configs = service.get_load_balancing_configs(
        "tenant-1",
        "openai",
        "gpt-4o-mini",
        ModelType.LLM.value,
        config_from="predefined-model",
    )
    # Assert: inherit entry first with empty credentials; the normal entry keeps
    # the raw (undecryptable) value and reports its cooldown state.
    assert is_enabled is False
    assert configs[0]["name"] == "__inherit__"
    assert configs[0]["credentials"] == {}
    assert configs[1]["credentials"] == {"api_key": "bad-encrypted"}
    assert configs[1]["in_cooldown"] is True
    assert configs[1]["ttl"] == 15
def test_get_load_balancing_config_should_raise_value_error_when_provider_missing(
    service: ModelLoadBalancingService,
) -> None:
    """Fetching a single config for an unknown provider raises immediately."""
    # Arrange: the tenant has no configured providers at all
    service.provider_manager.get_configurations.return_value = {}
    # Act + Assert
    with pytest.raises(ValueError, match="Provider openai does not exist"):
        service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
def test_get_load_balancing_config_should_return_none_when_config_not_found(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """An unknown config id yields None rather than an error."""
    # Arrange: DB lookup finds no matching row
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    # Act
    result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
    # Assert
    assert result is None
def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """A found config is returned with its credentials passed through obfuscation."""
    # Arrange: obfuscation is stubbed to collapse credentials into a single 'masked' key
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: {
        "masked": credentials.get("api_key", "")
    }
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
    mock_db.session.query.return_value.where.return_value.first.return_value = config
    # Act
    result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
    # Assert: unparseable encrypted_config degrades to empty credentials before masking
    assert result == {
        "id": "cfg-1",
        "name": "primary",
        "credentials": {"masked": ""},
        "enabled": True,
    }
def test_init_inherit_config_should_create_and_persist_inherit_configuration(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """_init_inherit_config persists a '__inherit__' row with the model type's string form."""
    # Act
    created = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM)
    # Assert: identity columns echo the arguments; LLM is stored as "text-generation"
    assert created.tenant_id == "tenant-1"
    assert created.provider_name == "openai"
    assert created.model_name == "gpt-4o-mini"
    assert created.model_type == "text-generation"
    assert created.name == "__inherit__"
    mock_db.session.add.assert_called_once_with(created)
    mock_db.session.commit.assert_called_once()
def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing(
    service: ModelLoadBalancingService,
) -> None:
    """Updating configs for an unknown provider raises immediately."""
    # Arrange: the tenant has no configured providers at all
    service.provider_manager.get_configurations.return_value = {}
    # Act + Assert
    with pytest.raises(ValueError, match="Provider openai does not exist"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [],
            "custom-model",
        )
def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list(
    service: ModelLoadBalancingService,
) -> None:
    """Passing a non-list configs argument is rejected with "Invalid load balancing configs"."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    # Act + Assert
    with pytest.raises(ValueError, match="Invalid load balancing configs"):
        service.update_load_balancing_configs(  # type: ignore[arg-type]
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            # cast smuggles a plain string past the type checker to hit the runtime guard.
            cast(list[dict[str, object]], "invalid-configs"),
            "custom-model",
        )
def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """A list element that is not a dict is rejected with "Invalid load balancing config"."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    # No pre-existing configs in the database.
    mock_db.session.scalars.return_value.all.return_value = []
    # Act + Assert
    with pytest.raises(ValueError, match="Invalid load balancing config"):
        service.update_load_balancing_configs(  # type: ignore[list-item]
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            cast(list[dict[str, object]], ["bad-item"]),
            "custom-model",
        )
def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """Referencing a provider credential id that does not exist in the DB fails the update."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    mock_db.session.scalars.return_value.all.return_value = []
    # The credential lookup by id returns nothing.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"credential_id": "cred-1", "enabled": True}],
            "predefined-model",
        )
def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """Each config item must carry both a name and an enabled flag; each omission raises."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    mock_db.session.scalars.return_value.all.return_value = []
    # Act + Assert
    # Missing "name" key.
    with pytest.raises(ValueError, match="Invalid load balancing config name"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"enabled": True}],
            "custom-model",
        )
    # Missing "enabled" key.
    with pytest.raises(ValueError, match="Invalid load balancing config enabled"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"name": "cfg-without-enabled"}],
            "custom-model",
        )
def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """Supplying an id that matches no current config row is rejected."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    # Only cfg-1 exists; the payload references cfg-2.
    current_config = SimpleNamespace(id="cfg-1")
    mock_db.session.scalars.return_value.all.return_value = [current_config]
    # Act + Assert
    with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"id": "cfg-2", "name": "invalid", "enabled": True}],
            "custom-model",
        )
def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """Non-dict credentials are rejected both on the update path (with id) and the create path."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None)
    mock_db.session.scalars.return_value.all.return_value = [existing_config]
    # Act + Assert
    # Update path: item has an id matching an existing config.
    with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}],
            "custom-model",
        )
    # Create path: item has no id, so it would be a new config.
    with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"name": "new-config", "enabled": True, "credentials": "bad"}],
            "custom-model",
        )
def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """One call exercises all three branches: update cfg-1, create a new config, delete cfg-2.

    cfg-2 is absent from the payload, so it must be removed; both touched configs get
    their credential cache cleared.
    """
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    existing_config_1 = SimpleNamespace(
        id="cfg-1",
        name="existing-one",
        enabled=True,
        encrypted_config=json.dumps({"api_key": "old"}),
        updated_at=None,
    )
    existing_config_2 = SimpleNamespace(
        id="cfg-2",
        name="existing-two",
        enabled=True,
        encrypted_config=None,
        updated_at=None,
    )
    mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2]
    # Credential validation/encryption is stubbed so we only verify orchestration.
    mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"})
    mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache")
    # Act
    service.update_load_balancing_configs(
        "tenant-1",
        "openai",
        "gpt-4o-mini",
        ModelType.LLM.value,
        [
            {"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}},
            {"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}},
        ],
        "custom-model",
    )
    # Assert
    assert existing_config_1.name == "updated-name"
    assert existing_config_1.enabled is False
    assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"}
    # Exactly one new row is added (the new-config item); cfg-2 is deleted.
    assert mock_db.session.add.call_count == 1
    mock_db.session.delete.assert_called_once_with(existing_config_2)
    # At least one commit per branch (update / create / delete).
    assert mock_db.session.commit.call_count >= 3
    mock_clear_cache.assert_any_call("tenant-1", "cfg-1")
    mock_clear_cache.assert_any_call("tenant-1", "cfg-2")
def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """New configs may not use the reserved "__inherit__" name nor omit credentials."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    mock_db.session.scalars.return_value.all.return_value = []
    # Act + Assert
    # "__inherit__" is reserved for the inherited configuration row.
    with pytest.raises(ValueError, match="Invalid load balancing config name"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}],
            "custom-model",
        )
    # A brand-new config must supply credentials.
    with pytest.raises(ValueError, match="Invalid load balancing config credentials"):
        service.update_load_balancing_configs(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            [{"name": "new", "enabled": True}],
            "custom-model",
        )
def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """A credential_id item clones the stored provider credential into a new LB config row."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    mock_db.session.scalars.return_value.all.return_value = []
    # The referenced provider credential exists in the database.
    credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
    # Act
    service.update_load_balancing_configs(
        "tenant-1",
        "openai",
        "gpt-4o-mini",
        ModelType.LLM.value,
        [{"credential_id": "cred-1", "enabled": True}],
        "predefined-model",
    )
    # Assert
    # The new row inherits name and encrypted config from the provider credential.
    created_config = mock_db.session.add.call_args.args[0]
    assert created_config.name == "Main Credential"
    assert created_config.credential_id == "cred-1"
    assert created_config.credential_source_type == "provider"
    assert created_config.encrypted_config == '{"api_key":"enc"}'
    mock_db.session.commit.assert_called()
def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing(
    service: ModelLoadBalancingService,
) -> None:
    """Validation cannot proceed when the provider is unknown to the manager."""
    # Arrange: no provider configurations are available.
    service.provider_manager.get_configurations.return_value = {}

    # Act + Assert
    with pytest.raises(ValueError, match="Provider openai does not exist"):
        service.validate_load_balancing_credentials(
            "tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, {"api_key": "plain"}
        )
def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
) -> None:
    """Passing a config_id that matches no stored config raises a descriptive error."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    # The config lookup returns nothing.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
        service.validate_load_balancing_credentials(
            "tenant-1",
            "openai",
            "gpt-4o-mini",
            ModelType.LLM.value,
            {"api_key": "plain"},
            config_id="cfg-1",
        )
def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config(
    service: ModelLoadBalancingService,
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Both call shapes delegate to _custom_credentials_validate, passing the found config or None."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
    existing_config = SimpleNamespace(id="cfg-1")
    mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
    mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
    # Act
    # First call resolves the config by id; second call omits config_id entirely.
    service.validate_load_balancing_credentials(
        "tenant-1",
        "openai",
        "gpt-4o-mini",
        ModelType.LLM.value,
        {"api_key": "plain"},
        config_id="cfg-1",
    )
    service.validate_load_balancing_credentials(
        "tenant-1",
        "openai",
        "gpt-4o-mini",
        ModelType.LLM.value,
        {"api_key": "plain"},
    )
    # Assert
    assert mock_validate.call_count == 2
    assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config
    assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None
def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt(
    service: ModelLoadBalancingService,
    mocker: MockerFixture,
) -> None:
    """A HIDDEN_VALUE placeholder is swapped for the stored secret, then re-encrypted.

    Non-secret fields ("region") pass through untouched; validate=False skips schema validation.
    """
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    load_balancing_model_config = _load_balancing_model_config(
        encrypted_config=json.dumps({"api_key": "old-encrypted-token"})
    )
    mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value")
    mock_encrypt = mocker.patch(
        "services.model_load_balancing_service.encrypter.encrypt_token",
        side_effect=lambda tenant_id, value: f"enc:{value}",
    )
    # Act
    result = service._custom_credentials_validate(
        tenant_id="tenant-1",
        provider_configuration=provider_configuration,
        model_type=ModelType.LLM,
        model="gpt-4o-mini",
        credentials={"api_key": HIDDEN_VALUE, "region": "us"},
        load_balancing_model_config=load_balancing_model_config,
        validate=False,
    )
    # Assert
    assert result == {"api_key": "enc:old-plain-value", "region": "us"}
    # Only the secret field is re-encrypted; encrypt_token is called exactly once.
    mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value")
def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema(
    service: ModelLoadBalancingService,
    mocker: MockerFixture,
) -> None:
    """Unparseable stored config is tolerated, and a model credential schema routes
    validation through ModelProviderFactory.model_credentials_validate."""
    # Arrange
    provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema())
    # Corrupt stored JSON must not crash the validation path.
    load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
    mock_factory = MagicMock()
    mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
    mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
    mock_encrypt = mocker.patch(
        "services.model_load_balancing_service.encrypter.encrypt_token",
        side_effect=lambda tenant_id, value: f"enc:{value}",
    )
    # Act
    result = service._custom_credentials_validate(
        tenant_id="tenant-1",
        provider_configuration=provider_configuration,
        model_type=ModelType.LLM,
        model="gpt-4o-mini",
        credentials={"api_key": "plain"},
        load_balancing_model_config=load_balancing_model_config,
        validate=True,
    )
    # Assert
    assert result == {"api_key": "enc:validated"}
    # Model-level validation is used; provider-level validation is not.
    mock_factory.model_credentials_validate.assert_called_once()
    mock_factory.provider_credentials_validate.assert_not_called()
    mock_encrypt.assert_called_once_with("tenant-1", "validated")
def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent(
    service: ModelLoadBalancingService,
    mocker: MockerFixture,
) -> None:
    """Without a model credential schema, validation falls back to the provider schema path."""
    # Arrange
    provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
    mock_factory = MagicMock()
    mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
    mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
    mocker.patch(
        "services.model_load_balancing_service.encrypter.encrypt_token",
        side_effect=lambda tenant_id, value: f"enc:{value}",
    )
    # Act
    result = service._custom_credentials_validate(
        tenant_id="tenant-1",
        provider_configuration=provider_configuration,
        model_type=ModelType.LLM,
        model="gpt-4o-mini",
        credentials={"api_key": "plain"},
        validate=True,
    )
    # Assert
    assert result == {"api_key": "enc:provider-validated"}
    # Provider-level validation is used; model-level validation is not.
    mock_factory.provider_credentials_validate.assert_called_once()
    mock_factory.model_credentials_validate.assert_not_called()
def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise(
    service: ModelLoadBalancingService,
) -> None:
    """_get_credential_schema prefers the model schema, falls back to the provider schema,
    and raises when neither exists."""
    # Arrange
    model_schema = _build_model_credential_schema()
    provider_schema = _build_provider_credential_schema()
    provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema)
    provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema)
    provider_configuration_without_schema = _build_provider_configuration()
    # Act
    schema_from_model = service._get_credential_schema(provider_configuration_with_model)
    schema_from_provider = service._get_credential_schema(provider_configuration_with_provider)
    # Assert
    # Identity checks: the same schema objects are returned, not copies.
    assert schema_from_model is model_schema
    assert schema_from_provider is provider_schema
    with pytest.raises(ValueError, match="No credential schema found"):
        service._get_credential_schema(provider_configuration_without_schema)
def test_clear_credentials_cache_should_delete_load_balancing_cache_entry(
    service: ModelLoadBalancingService,
    mocker: MockerFixture,
) -> None:
    """_clear_credentials_cache deletes the LOAD_BALANCING_MODEL cache entry for the config."""
    # Arrange
    mock_cache_instance = MagicMock()
    mock_cache_cls = mocker.patch(
        "services.model_load_balancing_service.ProviderCredentialsCache",
        return_value=mock_cache_instance,
    )
    # Act
    service._clear_credentials_cache("tenant-1", "cfg-1")
    # Assert
    mock_cache_cls.assert_called_once()
    # cache_type is an enum we match by name below, so accept ANY here.
    assert mock_cache_cls.call_args.kwargs == {
        "tenant_id": "tenant-1",
        "identity_id": "cfg-1",
        "cache_type": mocker.ANY,
    }
    assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL"
    mock_cache_instance.delete.assert_called_once()

View File

@ -0,0 +1,224 @@
from __future__ import annotations
import uuid
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import BadRequest
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
    """Patch the module-level redis client used by the OAuth server service."""
    return mocker.patch("services.oauth_server.redis_client")
@pytest.fixture
def mock_session(mocker: MockerFixture) -> MagicMock:
    """Mock the OAuth server Session context manager."""
    # Replace db with a bare namespace so any Session(db.engine) call gets a harmless object.
    mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
    session = MagicMock()
    session_cm = MagicMock()
    # The service enters the session via `with Session(...) as session`; wire __enter__ accordingly.
    session_cm.__enter__.return_value = session
    mocker.patch("services.oauth_server.Session", return_value=session_cm)
    return session
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
    """get_oauth_provider_app returns the single row resolved by scalar_one_or_none."""
    # Arrange
    mock_execute_result = MagicMock()
    expected_app = MagicMock()
    mock_execute_result.scalar_one_or_none.return_value = expected_app
    mock_session.execute.return_value = mock_execute_result
    # Act
    result = OAuthServerService.get_oauth_provider_app("client-1")
    # Assert
    assert result is expected_app
    mock_session.execute.assert_called_once()
    mock_execute_result.scalar_one_or_none.assert_called_once()
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
    mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
    """The authorization code is a fresh UUID stored in Redis (keyed per client) with a 600s TTL."""
    # Arrange
    # Pin uuid4 so the generated code is predictable.
    deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
    mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
    # Act
    code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
    # Assert
    expected_code = str(deterministic_uuid)
    assert code == expected_code
    mock_redis_client.set.assert_called_once_with(
        OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
        "user-1",
        ex=600,
    )
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
    mock_redis_client: MagicMock,
) -> None:
    """An authorization code unknown to Redis must be rejected with HTTP 400."""
    # Arrange: Redis has no entry for the presented code.
    mock_redis_client.get.return_value = None

    # Act + Assert
    with pytest.raises(BadRequest, match="invalid code"):
        OAuthServerService.sign_oauth_access_token(
            grant_type=OAuthGrantType.AUTHORIZATION_CODE, code="bad-code", client_id="client-1"
        )
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
    mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
    """A valid code yields an access + refresh token pair; the code itself is consumed (deleted)."""
    # Arrange
    # Two pinned UUIDs: the first becomes the access token, the second the refresh token.
    token_uuids = [
        uuid.UUID("00000000-0000-0000-0000-000000000201"),
        uuid.UUID("00000000-0000-0000-0000-000000000202"),
    ]
    mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
    mock_redis_client.get.return_value = b"user-1"
    code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
    # Act
    access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
        grant_type=OAuthGrantType.AUTHORIZATION_CODE,
        code="code-1",
        client_id="client-1",
    )
    # Assert
    assert access_token == str(token_uuids[0])
    assert refresh_token == str(token_uuids[1])
    # The one-time authorization code is removed from Redis.
    mock_redis_client.delete.assert_called_once_with(code_key)
    mock_redis_client.set.assert_any_call(
        OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
        b"user-1",
        ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
    )
    mock_redis_client.set.assert_any_call(
        OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
        b"user-1",
        ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
    )
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
    mock_redis_client: MagicMock,
) -> None:
    """A refresh token unknown to Redis must be rejected with HTTP 400."""
    # Arrange: Redis lookup for the refresh token finds nothing.
    mock_redis_client.get.return_value = None

    # Act + Assert
    with pytest.raises(BadRequest, match="invalid refresh token"):
        OAuthServerService.sign_oauth_access_token(
            grant_type=OAuthGrantType.REFRESH_TOKEN, refresh_token="stale-token", client_id="client-1"
        )
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
    mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
    """Refresh-token grant mints a new access token but returns the same refresh token."""
    # Arrange
    deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
    mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
    mock_redis_client.get.return_value = b"user-1"
    # Act
    access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
        grant_type=OAuthGrantType.REFRESH_TOKEN,
        refresh_token="refresh-1",
        client_id="client-1",
    )
    # Assert
    assert access_token == str(deterministic_uuid)
    # The refresh token is not rotated on this path.
    assert returned_refresh_token == "refresh-1"
    mock_redis_client.set.assert_called_once_with(
        OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
        b"user-1",
        ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
    )
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
    """A grant type outside the enum falls through every branch and yields no tokens."""
    # Arrange: cast smuggles an invalid value past the type checker.
    bogus_grant = cast(OAuthGrantType, "invalid-grant-type")

    # Act
    outcome = OAuthServerService.sign_oauth_access_token(grant_type=bogus_grant, client_id="client-1")

    # Assert
    assert outcome is None
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
    mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
    """_sign_oauth_refresh_token stores a fresh UUID under the per-client key with the refresh TTL."""
    # Arrange
    deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
    mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
    # Act
    refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
    # Assert
    assert refresh_token == str(deterministic_uuid)
    mock_redis_client.set.assert_called_once_with(
        OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
        "user-2",
        ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
    )
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
    mock_redis_client: MagicMock,
) -> None:
    """An access token absent from Redis resolves to no user."""
    # Arrange: the token lookup misses.
    mock_redis_client.get.return_value = None

    # Act
    user = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")

    # Assert
    assert user is None
def test_validate_oauth_access_token_should_load_user_when_token_exists(
    mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
    """A stored token resolves to a user via AccountService.load_user with the decoded user id."""
    # Arrange
    # Redis returns bytes; the service is expected to decode them to "user-88".
    mock_redis_client.get.return_value = b"user-88"
    expected_user = MagicMock()
    mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
    # Act
    result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
    # Assert
    assert result is expected_user
    mock_load_user.assert_called_once_with("user-88")

File diff suppressed because it is too large. Load the diff to view it.

View File

@ -0,0 +1,259 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.model import App, EndUser
from services.web_conversation_service import WebConversationService
@pytest.fixture
def app_model() -> App:
    """Minimal stand-in for an App exposing only the `id` attribute the tests read."""
    return cast(App, SimpleNamespace(id="app-1"))
def _account(**kwargs: Any) -> Account:
    """Build a lightweight ``Account`` stand-in from arbitrary keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(Account, stub)
def _end_user(**kwargs: Any) -> EndUser:
    """Build a lightweight ``EndUser`` stand-in from arbitrary keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(EndUser, stub)
def test_pagination_by_last_id_should_raise_error_when_user_is_none(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """Pagination must refuse to run without a user."""
    # Arrange: stub the underlying service so only the guard clause can run.
    fake_session = MagicMock()
    mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")

    # Act + Assert
    with pytest.raises(ValueError, match="User is required"):
        WebConversationService.pagination_by_last_id(
            session=fake_session,
            app_model=app_model,
            user=None,
            last_id=None,
            limit=20,
            invoke_from=InvokeFrom.WEB_APP,
        )
def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """With pinned=None the call is forwarded with no include/exclude id filters."""
    # Arrange
    session = MagicMock()
    fake_user = _account(id="user-1")
    mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
    mock_pagination.return_value = MagicMock()
    # Act
    WebConversationService.pagination_by_last_id(
        session=session,
        app_model=app_model,
        user=fake_user,
        last_id="conv-9",
        limit=10,
        invoke_from=InvokeFrom.WEB_APP,
        pinned=None,
    )
    # Assert
    call_kwargs = mock_pagination.call_args.kwargs
    assert call_kwargs["include_ids"] is None
    assert call_kwargs["exclude_ids"] is None
    assert call_kwargs["last_id"] == "conv-9"
    assert call_kwargs["sort_by"] == "-updated_at"
def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """pinned=True restricts pagination to the user's pinned conversation ids."""
    # Arrange
    session = MagicMock()
    # Patch Account/EndUser in the service module so its isinstance checks match our fakes.
    fake_account_cls = type("FakeAccount", (), {})
    fake_user = cast(Account, fake_account_cls())
    fake_user.id = "account-1"
    mocker.patch("services.web_conversation_service.Account", fake_account_cls)
    mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {}))
    # Pinned-conversation query yields these ids.
    session.scalars.return_value.all.return_value = ["conv-1", "conv-2"]
    mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
    mock_pagination.return_value = MagicMock()
    # Act
    WebConversationService.pagination_by_last_id(
        session=session,
        app_model=app_model,
        user=fake_user,
        last_id=None,
        limit=20,
        invoke_from=InvokeFrom.WEB_APP,
        pinned=True,
    )
    # Assert
    call_kwargs = mock_pagination.call_args.kwargs
    assert call_kwargs["include_ids"] == ["conv-1", "conv-2"]
    assert call_kwargs["exclude_ids"] is None
def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """pinned=False excludes the user's pinned conversation ids from pagination."""
    # Arrange
    session = MagicMock()
    # Patch Account/EndUser in the service module so its isinstance checks match our fakes.
    fake_end_user_cls = type("FakeEndUser", (), {})
    fake_user = cast(EndUser, fake_end_user_cls())
    fake_user.id = "end-user-1"
    mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
    mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
    session.scalars.return_value.all.return_value = ["conv-3"]
    mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
    mock_pagination.return_value = MagicMock()
    # Act
    WebConversationService.pagination_by_last_id(
        session=session,
        app_model=app_model,
        user=fake_user,
        last_id=None,
        limit=20,
        invoke_from=InvokeFrom.WEB_APP,
        pinned=False,
    )
    # Assert
    call_kwargs = mock_pagination.call_args.kwargs
    assert call_kwargs["include_ids"] is None
    assert call_kwargs["exclude_ids"] == ["conv-3"]
def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
    """Pinning without a user is a no-op: nothing is added or committed."""
    # Arrange
    db_stub = mocker.patch("services.web_conversation_service.db")
    mocker.patch("services.web_conversation_service.ConversationService.get_conversation")

    # Act
    WebConversationService.pin(app_model, "conv-1", None)

    # Assert: the session was never touched.
    db_stub.session.add.assert_not_called()
    db_stub.session.commit.assert_not_called()
def test_pin_should_return_early_when_conversation_is_already_pinned(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """If a pinned record already exists, pin() neither looks up the conversation nor writes."""
    # Arrange
    fake_account_cls = type("FakeAccount", (), {})
    fake_user = cast(Account, fake_account_cls())
    fake_user.id = "account-1"
    mocker.patch("services.web_conversation_service.Account", fake_account_cls)
    mock_db = mocker.patch("services.web_conversation_service.db")
    # Existing pinned record found.
    mock_db.session.query.return_value.where.return_value.first.return_value = object()
    mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
    # Act
    WebConversationService.pin(app_model, "conv-1", fake_user)
    # Assert
    mock_get_conversation.assert_not_called()
    mock_db.session.add.assert_not_called()
    mock_db.session.commit.assert_not_called()
def test_pin_should_create_pinned_conversation_when_not_already_pinned(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """pin() resolves the conversation and persists a new pinned record for the account."""
    # Arrange
    fake_account_cls = type("FakeAccount", (), {})
    fake_user = cast(Account, fake_account_cls())
    fake_user.id = "account-2"
    mocker.patch("services.web_conversation_service.Account", fake_account_cls)
    mock_db = mocker.patch("services.web_conversation_service.db")
    # No existing pinned record.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    mock_conversation = SimpleNamespace(id="conv-2")
    mock_get_conversation = mocker.patch(
        "services.web_conversation_service.ConversationService.get_conversation",
        return_value=mock_conversation,
    )
    # Act
    WebConversationService.pin(app_model, "conv-2", fake_user)
    # Assert
    mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user)
    # Inspect the object handed to session.add: it must reference app, conversation, and creator.
    added_obj = mock_db.session.add.call_args.args[0]
    assert added_obj.app_id == "app-1"
    assert added_obj.conversation_id == "conv-2"
    assert added_obj.created_by_role == "account"
    assert added_obj.created_by == "account-2"
    mock_db.session.commit.assert_called_once()
def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
    """Unpinning without a user is a no-op: nothing is deleted or committed."""
    # Arrange
    db_stub = mocker.patch("services.web_conversation_service.db")

    # Act
    WebConversationService.unpin(app_model, "conv-1", None)

    # Assert: the session was never touched.
    db_stub.session.delete.assert_not_called()
    db_stub.session.commit.assert_not_called()
def test_unpin_should_return_early_when_conversation_is_not_pinned(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """unpin() does nothing when no pinned record exists for the user/conversation."""
    # Arrange
    fake_end_user_cls = type("FakeEndUser", (), {})
    fake_user = cast(EndUser, fake_end_user_cls())
    fake_user.id = "end-user-3"
    mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
    mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
    mock_db = mocker.patch("services.web_conversation_service.db")
    # Pinned-record lookup misses.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    # Act
    WebConversationService.unpin(app_model, "conv-7", fake_user)
    # Assert
    mock_db.session.delete.assert_not_called()
    mock_db.session.commit.assert_not_called()
def test_unpin_should_delete_pinned_conversation_when_exists(
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """unpin() deletes the pinned record and commits when one exists."""
    # Arrange
    fake_end_user_cls = type("FakeEndUser", (), {})
    fake_user = cast(EndUser, fake_end_user_cls())
    fake_user.id = "end-user-4"
    mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
    mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
    mock_db = mocker.patch("services.web_conversation_service.db")
    pinned_obj = SimpleNamespace(id="pin-1")
    mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj
    # Act
    WebConversationService.unpin(app_model, "conv-8", fake_user)
    # Assert
    mock_db.session.delete.assert_called_once_with(pinned_obj)
    mock_db.session.commit.assert_called_once()

View File

@ -0,0 +1,379 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound, Unauthorized
from models import Account, AccountStatus
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback"
TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token"
TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data"
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
    """Patch the webapp auth service's db handle, giving it a mock session."""
    # Arrange
    mocked_db = mocker.patch("services.webapp_auth_service.db")
    mocked_db.session = MagicMock()
    return mocked_db
def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None:
    """Authenticating an unknown email surfaces AccountNotFoundError."""
    # Arrange: the account lookup misses.
    mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)

    # Act + Assert
    with pytest.raises(AccountNotFoundError):
        WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
# Act + Assert
with pytest.raises(AccountLoginError, match="Account is banned"):
WebAppAuthService.authenticate("user@example.com", "pwd")
@pytest.mark.parametrize("password_value", [None, "hash"])
def test_authenticate_should_raise_password_error_when_password_is_invalid(
password_value: str | None,
mocker: MockerFixture,
) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=False)
# Act + Assert
with pytest.raises(AccountPasswordError, match="Invalid email or password"):
WebAppAuthService.authenticate("user@example.com", "pwd")
def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None:
# Arrange
account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt")
mocker.patch(
ACCOUNT_LOOKUP_PATH,
return_value=account,
)
mocker.patch("services.webapp_auth_service.compare_password", return_value=True)
# Act
result = WebAppAuthService.authenticate("user@example.com", "pwd")
# Assert
assert result is account
def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None:
    """login delegates token creation to _get_account_jwt_token and returns its result."""
    # Arrange
    account = _account(id="a1", email="u@example.com")
    mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token")
    # Act
    result = WebAppAuthService.login(account)
    # Assert
    assert result == "jwt-token"
    mock_get_token.assert_called_once_with(account=account)


def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None:
    """A missing account yields None rather than raising."""
    # Arrange
    mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None)
    # Act
    result = WebAppAuthService.get_user_through_email("missing@example.com")
    # Assert
    assert result is None


def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None:
    """A banned account is rejected with HTTP 401 Unauthorized."""
    # Arrange
    account = SimpleNamespace(status=AccountStatus.BANNED)
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )
    # Act + Assert
    with pytest.raises(Unauthorized, match="Account is banned"):
        WebAppAuthService.get_user_through_email("user@example.com")


def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None:
    """An active account is returned as-is."""
    # Arrange
    account = SimpleNamespace(status=AccountStatus.ACTIVE)
    mocker.patch(
        ACCOUNT_LOOKUP_PATH,
        return_value=account,
    )
    # Act
    result = WebAppAuthService.get_user_through_email("user@example.com")
    # Assert
    assert result is account


def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None:
    """Neither an account nor a raw email means there is nowhere to send the code."""
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Email must be provided"):
        WebAppAuthService.send_email_code_login_email(account=None, email=None)


def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account(
    mocker: MockerFixture,
) -> None:
    """The 6 stubbed randbelow draws (1..6) yield code '123456'; token carries it, mail receives it."""
    # Arrange
    account = _account(email="user@example.com")
    mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6])
    mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1")
    mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
    # Act
    result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
    # Assert
    assert result == "token-1"
    mock_generate_token.assert_called_once()
    assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"}
    mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456")


def test_send_email_code_login_email_should_send_mail_for_email_without_account(
    mocker: MockerFixture,
) -> None:
    """With no account object, the explicit email argument is used as the recipient."""
    # Arrange
    mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0])
    mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2")
    mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay")
    # Act
    result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans")
    # Assert
    assert result == "token-2"
    mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000")


def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
    """Lookup is a straight delegation to TokenManager.get_token_data."""
    # Arrange
    mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"})
    # Act
    result = WebAppAuthService.get_email_code_login_data("token-abc")
    # Assert
    assert result == {"code": "123"}
    mock_get_data.assert_called_once_with("token-abc", "email_code_login")


def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None:
    """Revocation is a straight delegation to TokenManager.revoke_token."""
    # Arrange
    mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token")
    # Act
    WebAppAuthService.revoke_email_code_login_token("token-xyz")
    # Assert
    mock_revoke.assert_called_once_with("token-xyz", "email_code_login")
def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None:
    """An app code with no matching Site row surfaces as NotFound."""
    # Arrange
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(NotFound, match="Site not found"):
        WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None:
    """A Site that points at a missing App surfaces as NotFound."""
    # Arrange: first .first() call returns the site, second (the App lookup) returns None.
    # (Removed an `app_query` MagicMock that was configured but never wired into the
    # mocked session — the side_effect list below is what actually drives both lookups.)
    site = SimpleNamespace(app_id="app-1")
    mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None]
    # Act + Assert
    with pytest.raises(NotFound, match="App not found"):
        WebAppAuthService.create_end_user("app-code", "user@example.com")
def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None:
    """A valid site/app pair yields an EndUser added to and committed on the session."""
    # Arrange: the two sequential lookups resolve the Site, then its App.
    site = SimpleNamespace(app_id="app-1")
    app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
    mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model]
    # Act
    result = WebAppAuthService.create_end_user("app-code", "user@example.com")
    # Assert: the new end user inherits tenant/app ids and uses the email as session id.
    assert result.tenant_id == "tenant-1"
    assert result.app_id == "app-1"
    assert result.session_id == "user@example.com"
    mock_db.session.add.assert_called_once()
    mock_db.session.commit.assert_called_once()
def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None:
    """The JWT payload carries user/session ids, token markers, and a future expiry."""
    # Arrange
    account = _account(id="a1", email="user@example.com")
    mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60)
    mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1")
    # Act
    token = WebAppAuthService._get_account_jwt_token(account)
    # Assert
    assert token == "jwt-1"
    payload = mock_issue.call_args.args[0]
    assert payload["user_id"] == "a1"
    assert payload["session_id"] == "user@example.com"
    assert payload["token_source"] == "webapp_login_token"
    assert payload["auth_type"] == "internal"
    # Expiry must lie in the future relative to "now" (60-minute window patched above).
    assert payload["exp"] > int(datetime.now(UTC).timestamp())
@pytest.mark.parametrize(
    ("access_mode", "expected"),
    [
        ("private", True),
        ("private_all", True),
        ("public", False),
    ],
)
def test_is_app_require_permission_check_should_use_access_mode_when_provided(
    access_mode: str,
    expected: bool,
) -> None:
    """An explicit access_mode short-circuits any app lookup."""
    # Arrange
    # Act
    result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode)
    # Assert
    assert result is expected


def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None:
    """Without app_code, app_id, or access_mode there is nothing to check."""
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Either app_code or app_id must be provided"):
        WebAppAuthService.is_app_require_permission_check()


def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None:
    """A code that resolves to no app id is an error, not a silent False."""
    # Arrange
    mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None)
    # Act + Assert
    with pytest.raises(ValueError, match="App ID could not be determined"):
        WebAppAuthService.is_app_require_permission_check(app_code="app-code")


def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it(
    mocker: MockerFixture,
) -> None:
    """Enterprise 'private' access mode enforces the permission check."""
    # Arrange
    mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
    mocker.patch(
        "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
        return_value=SimpleNamespace(access_mode="private"),
    )
    # Act
    result = WebAppAuthService.is_app_require_permission_check(app_code="app-code")
    # Assert
    assert result is True


def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it(
    mocker: MockerFixture,
) -> None:
    """Enterprise 'public' access mode skips the permission check."""
    # Arrange
    mocker.patch(
        "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
        return_value=SimpleNamespace(access_mode="public"),
    )
    # Act
    result = WebAppAuthService.is_app_require_permission_check(app_id="app-1")
    # Assert
    assert result is False


@pytest.mark.parametrize(
    ("access_mode", "expected"),
    [
        ("public", WebAppAuthType.PUBLIC),
        ("private", WebAppAuthType.INTERNAL),
        ("private_all", WebAppAuthType.INTERNAL),
        ("sso_verified", WebAppAuthType.EXTERNAL),
    ],
)
def test_get_app_auth_type_should_map_access_modes_correctly(
    access_mode: str,
    expected: WebAppAuthType,
) -> None:
    """Each known access mode maps to exactly one WebAppAuthType."""
    # Arrange
    # Act
    result = WebAppAuthService.get_app_auth_type(access_mode=access_mode)
    # Assert
    assert result == expected


def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None:
    """Given only an app code, the mode is fetched via the enterprise access-mode lookup."""
    # Arrange
    mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1")
    mocker.patch(
        "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
        return_value=SimpleNamespace(access_mode="private_all"),
    )
    # Act
    result = WebAppAuthService.get_app_auth_type(app_code="app-code")
    # Assert
    assert result == WebAppAuthType.INTERNAL


def test_get_app_auth_type_should_raise_when_no_input_provided() -> None:
    """At least one of app_code / access_mode is required."""
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"):
        WebAppAuthService.get_app_auth_type()


def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None:
    """An unrecognized access mode is an error, not a default."""
    # Arrange
    # Act + Assert
    with pytest.raises(ValueError, match="Could not determine app authentication type"):
        WebAppAuthService.get_app_auth_type(access_mode="unknown")

View File

@ -0,0 +1,300 @@
from __future__ import annotations
import json
import uuid
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from dify_graph.enums import WorkflowExecutionStatus
from models import App, WorkflowAppLog
from models.enums import AppTriggerType, CreatorUserRole
from services.workflow_app_service import LogView, WorkflowAppService
@pytest.fixture
def service() -> WorkflowAppService:
    """Provide a fresh WorkflowAppService instance for each test."""
    instance = WorkflowAppService()
    return instance
@pytest.fixture
def app_model() -> App:
    """Provide a minimal App-like namespace with fixed tenant/app ids."""
    stub = SimpleNamespace(id="app-1", tenant_id="tenant-1")
    return cast(App, stub)
def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog:
    """Build a WorkflowAppLog stand-in from arbitrary keyword attributes."""
    stub = SimpleNamespace(**kwargs)
    return cast(WorkflowAppLog, stub)
def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None:
    """LogView exposes its own details and proxies unknown attributes to the wrapped log."""
    # Arrange
    log = _workflow_app_log(id="log-1", status="succeeded")
    view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
    # Act
    details = view.details
    proxied_status = view.status
    # Assert
    assert details == {"trigger_metadata": {"type": "plugin"}}
    assert proxied_status == "succeeded"


def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false(
    service: WorkflowAppService,
    app_model: App,
) -> None:
    """detail=False returns LogView rows without details; has_more reflects total vs page*limit."""
    # Arrange
    session = MagicMock()
    log_1 = SimpleNamespace(id="log-1")
    log_2 = SimpleNamespace(id="log-2")
    session.scalar.return_value = 3
    session.scalars.return_value.all.return_value = [log_1, log_2]
    # Act
    result = service.get_paginate_workflow_app_logs(
        session=session,
        app_model=app_model,
        page=1,
        limit=2,
        detail=False,
    )
    # Assert
    assert result["page"] == 1
    assert result["limit"] == 2
    assert result["total"] == 3
    assert result["has_more"] is True
    assert len(result["data"]) == 2
    assert isinstance(result["data"][0], LogView)
    assert result["data"][0].details is None


def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true(
    service: WorkflowAppService,
    app_model: App,
    mocker: MockerFixture,
) -> None:
    """detail=True attaches trigger metadata resolved via handle_trigger_metadata."""
    # Arrange
    session = MagicMock()
    session.scalar.side_effect = [1]
    log_1 = SimpleNamespace(id="log-1")
    # The detail path uses session.execute and returns (log, raw_metadata) tuples.
    session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')]
    mock_handle = mocker.patch.object(
        service,
        "handle_trigger_metadata",
        return_value={"type": "trigger_plugin", "icon": "url"},
    )
    # Act
    result = service.get_paginate_workflow_app_logs(
        session=session,
        app_model=app_model,
        keyword="run-1",
        status=WorkflowExecutionStatus.SUCCEEDED,
        created_at_before=None,
        created_at_after=None,
        page=1,
        limit=20,
        detail=True,
    )
    # Assert
    assert result["total"] == 1
    assert len(result["data"]) == 1
    assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}}
    mock_handle.assert_called_once()


def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found(
    service: WorkflowAppService,
    app_model: App,
) -> None:
    """Filtering by an unknown account email raises rather than returning empty results."""
    # Arrange
    session = MagicMock()
    session.scalar.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Account not found: account@example.com"):
        service.get_paginate_workflow_app_logs(
            session=session,
            app_model=app_model,
            created_by_account="account@example.com",
        )


def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists(
    service: WorkflowAppService,
    app_model: App,
) -> None:
    """A resolvable account email narrows the query; an empty match set is a valid result."""
    # Arrange: first scalar() resolves the account, second is the count query.
    session = MagicMock()
    session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0]
    session.scalars.return_value.all.return_value = []
    # Act
    result = service.get_paginate_workflow_app_logs(
        session=session,
        app_model=app_model,
        created_by_account="account@example.com",
    )
    # Assert
    assert result["total"] == 0
    assert result["data"] == []
def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items(
    service: WorkflowAppService,
    app_model: App,
) -> None:
    """Archive rows are enriched with their creator (account, end user, or neither)."""
    # Arrange: one log per creator role, plus one with an unrecognized role.
    session = MagicMock()
    log_account = SimpleNamespace(
        id="log-1",
        created_by="acc-1",
        created_by_role=CreatorUserRole.ACCOUNT,
        workflow_run_summary={"run": "1"},
        trigger_metadata='{"type":"trigger-webhook"}',
        log_created_at="2026-01-01",
    )
    log_end_user = SimpleNamespace(
        id="log-2",
        created_by="end-1",
        created_by_role=CreatorUserRole.END_USER,
        workflow_run_summary={"run": "2"},
        trigger_metadata='{"type":"trigger-webhook"}',
        log_created_at="2026-01-02",
    )
    log_unknown = SimpleNamespace(
        id="log-3",
        created_by="other",
        created_by_role="system",
        workflow_run_summary={"run": "3"},
        trigger_metadata='{"type":"trigger-webhook"}',
        log_created_at="2026-01-03",
    )
    session.scalar.return_value = 3
    # scalars() is called three times: the log page, then the Account and EndUser lookups.
    session.scalars.side_effect = [
        SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]),
        SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]),
        SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]),
    ]
    # Act
    result = service.get_paginate_workflow_archive_logs(
        session=session,
        app_model=app_model,
        page=1,
        limit=20,
    )
    # Assert
    assert result["total"] == 3
    assert len(result["data"]) == 3
    assert result["data"][0]["created_by_account"].id == "acc-1"
    assert result["data"][1]["created_by_end_user"].id == "end-1"
    # An unknown role maps to neither creator kind.
    assert result["data"][2]["created_by_account"] is None
    assert result["data"][2]["created_by_end_user"] is None
def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing(
    service: WorkflowAppService,
) -> None:
    """Missing metadata resolves to an empty dict, never None."""
    # Arrange
    # Act
    result = service.handle_trigger_metadata("tenant-1", None)
    # Assert
    assert result == {}


def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin(
    service: WorkflowAppService,
    mocker: MockerFixture,
) -> None:
    """Plugin-triggered runs get light and dark icon URLs resolved via PluginService."""
    # Arrange
    meta = {
        "type": AppTriggerType.TRIGGER_PLUGIN.value,
        "icon_filename": "light.png",
        "icon_dark_filename": "dark.png",
    }
    mock_icon = mocker.patch(
        "services.workflow_app_service.PluginService.get_plugin_icon_url",
        side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
    )
    # Act
    result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
    # Assert
    assert result["icon"] == "https://cdn/light.png"
    assert result["icon_dark"] == "https://cdn/dark.png"
    assert mock_icon.call_count == 2


def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup(
    service: WorkflowAppService,
    mocker: MockerFixture,
) -> None:
    """Non-plugin triggers pass through untouched; no icon lookup is made."""
    # Arrange
    meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
    mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url")
    # Act
    result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
    # Assert
    assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
    mock_icon.assert_not_called()


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (None, None),
        ("", None),
        ('{"k":"v"}', {"k": "v"}),
        ("not-json", None),
        ({"raw": True}, {"raw": True}),  # already-parsed dicts pass through
    ],
)
def test_safe_json_loads_should_handle_various_inputs(
    value: object,
    expected: object,
    service: WorkflowAppService,
) -> None:
    """_safe_json_loads never raises: bad input maps to None, good input to a dict."""
    # Arrange
    # Act
    result = service._safe_json_loads(value)
    # Assert
    assert result == expected


def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None:
    """Too-short and malformed strings both yield None instead of raising."""
    # Arrange
    # Act
    short_result = service._safe_parse_uuid("short")
    invalid_result = service._safe_parse_uuid("x" * 40)
    # Assert
    assert short_result is None
    assert invalid_result is None


def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None:
    """A valid UUID string round-trips through parsing."""
    # Arrange
    raw_uuid = str(uuid.uuid4())
    # Act
    result = service._safe_parse_uuid(raw_uuid)
    # Assert
    assert result is not None
    assert str(result) == raw_uuid

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,576 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from models.account import Tenant
# ---------------------------------------------------------------------------
# Constants used throughout the tests
# ---------------------------------------------------------------------------
TENANT_ID = "tenant-abc"
ACCOUNT_ID = "account-xyz"
FILES_BASE_URL = "https://files.example.com"
# Dotted patch targets into services.workspace_service and its collaborators.
DB_PATH = "services.workspace_service.db"
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
CURRENT_USER_PATH = "services.workspace_service.current_user"
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
# ---------------------------------------------------------------------------
# Helpers / factories
# ---------------------------------------------------------------------------
def _make_tenant(
    tenant_id: str = TENANT_ID,
    name: str = "My Workspace",
    plan: str = "sandbox",
    status: str = "active",
    custom_config: dict | None = None,
) -> Tenant:
    """Create a minimal Tenant-like namespace."""
    config = {} if not custom_config else custom_config
    stub = SimpleNamespace(
        id=tenant_id,
        name=name,
        plan=plan,
        status=status,
        created_at="2024-01-01T00:00:00Z",
        custom_config_dict=config,
    )
    return cast(Tenant, stub)
def _make_feature(
can_replace_logo: bool = False,
next_credit_reset_date: str | None = None,
billing_plan: str = "sandbox",
) -> MagicMock:
"""Create a feature namespace matching what FeatureService.get_features returns."""
feature = MagicMock()
feature.can_replace_logo = can_replace_logo
feature.next_credit_reset_date = next_credit_reset_date
feature.billing.subscription.plan = billing_plan
return feature
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
pool = MagicMock()
pool.quota_limit = quota_limit
pool.quota_used = quota_used
return pool
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
return SimpleNamespace(role=role)
def _tenant_info(result: object) -> dict[str, Any] | None:
return cast(dict[str, Any] | None, result)
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_current_user() -> SimpleNamespace:
    """Provide a lightweight current_user stand-in carrying only an id."""
    user = SimpleNamespace(id=ACCOUNT_ID)
    return user
@pytest.fixture
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
    """
    Patch the common external boundaries used by WorkspaceService.get_tenant_info.
    Returns a dict of named mocks so individual tests can customise them.
    """
    mocker.patch(CURRENT_USER_PATH, mock_current_user)
    # Wire db.session.query(...).where(...).first() to yield an owner join by default.
    mock_db_session = mocker.patch(f"{DB_PATH}.session")
    mock_query_chain = MagicMock()
    mock_db_session.query.return_value = mock_query_chain
    # .where() returns the chain itself so any chaining depth resolves to .first().
    mock_query_chain.where.return_value = mock_query_chain
    mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
    mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
    mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
    # Default to SELF_HOSTED so cloud-only credit logic stays inactive.
    mock_config = mocker.patch(DIFY_CONFIG_PATH)
    mock_config.EDITION = "SELF_HOSTED"
    mock_config.FILES_URL = FILES_BASE_URL
    return {
        "db_session": mock_db_session,
        "query_chain": mock_query_chain,
        "get_features": mock_feature,
        "has_roles": mock_has_roles,
        "config": mock_config,
    }
# ---------------------------------------------------------------------------
# 1. None Tenant Handling
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
    """get_tenant_info should short-circuit and return None for a falsy tenant."""
    from services.workspace_service import WorkspaceService

    # Arrange
    tenant = None
    # Act
    result = WorkspaceService.get_tenant_info(cast(Tenant, tenant))
    # Assert
    assert result is None


def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
    """get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
    from services.workspace_service import WorkspaceService

    # Arrange / Act / Assert
    assert WorkspaceService.get_tenant_info("") is None  # type: ignore[arg-type]


# ---------------------------------------------------------------------------
# 2. Basic Tenant Info — happy path
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_return_base_fields(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """get_tenant_info should always return the six base scalar fields."""
    from services.workspace_service import WorkspaceService

    # Arrange
    tenant = _make_tenant()
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert result is not None
    assert result["id"] == TENANT_ID
    assert result["name"] == "My Workspace"
    assert result["plan"] == "sandbox"
    assert result["status"] == "active"
    assert result["created_at"] == "2024-01-01T00:00:00Z"
    assert result["trial_end_reason"] is None


def test_get_tenant_info_should_populate_role_from_tenant_account_join(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """The 'role' field should be taken from TenantAccountJoin, not the default."""
    from services.workspace_service import WorkspaceService

    # Arrange: override the fixture's default "owner" join.
    basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
    tenant = _make_tenant()
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert result is not None
    assert result["role"] == "admin"


def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """
    The service asserts that TenantAccountJoin exists.
    Missing join should raise AssertionError.
    """
    from services.workspace_service import WorkspaceService

    # Arrange
    basic_mocks["query_chain"].first.return_value = None
    tenant = _make_tenant()
    # Act + Assert
    with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
        WorkspaceService.get_tenant_info(tenant)


# ---------------------------------------------------------------------------
# 3. Logo Customisation
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True."""
    from services.workspace_service import WorkspaceService

    # Arrange
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = True
    tenant = _make_tenant(
        custom_config={
            "replace_webapp_logo": True,
            "remove_webapp_brand": True,
        }
    )
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert: the truthy flag is materialised into a concrete logo URL.
    assert result is not None
    assert "custom_config" in result
    assert result["custom_config"]["remove_webapp_brand"] is True
    expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
    assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url


def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """replace_webapp_logo should be None when custom_config_dict does not have the key."""
    from services.workspace_service import WorkspaceService

    # Arrange
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = True
    tenant = _make_tenant(custom_config={})  # no replace_webapp_logo key
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert result is not None
    assert result["custom_config"]["replace_webapp_logo"] is None


def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """custom_config should be absent when can_replace_logo is False."""
    from services.workspace_service import WorkspaceService

    # Arrange
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
    basic_mocks["has_roles"].return_value = True
    tenant = _make_tenant()
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert result is not None
    assert "custom_config" not in result


def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """custom_config block is gated on OWNER or ADMIN role."""
    from services.workspace_service import WorkspaceService

    # Arrange
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = False  # regular member
    tenant = _make_tenant()
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert result is not None
    assert "custom_config" not in result


def test_get_tenant_info_should_use_files_url_for_logo_url(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """The logo URL should use dify_config.FILES_URL as the base."""
    from services.workspace_service import WorkspaceService

    # Arrange
    custom_base = "https://cdn.mycompany.io"
    basic_mocks["config"].FILES_URL = custom_base
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = True
    tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert result is not None
    assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
# ---------------------------------------------------------------------------
# 4. Cloud-Edition Credit Features
# ---------------------------------------------------------------------------
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional"  # any plan that is not SANDBOX


@pytest.fixture
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
    """Patches for CLOUD edition tests, billing plan = professional by default."""
    mocker.patch(CURRENT_USER_PATH, mock_current_user)
    # Same query-chain wiring as basic_mocks: .where() loops back, .first() yields an owner join.
    mock_db_session = mocker.patch(f"{DB_PATH}.session")
    mock_query_chain = MagicMock()
    mock_db_session.query.return_value = mock_query_chain
    mock_query_chain.where.return_value = mock_query_chain
    mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
    mock_feature = mocker.patch(
        FEATURE_SERVICE_PATH,
        return_value=_make_feature(
            can_replace_logo=False,
            next_credit_reset_date="2025-02-01",
            billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
        ),
    )
    mocker.patch(TENANT_SERVICE_PATH, return_value=False)
    # EDITION=CLOUD activates the credit-pool branch under test.
    mock_config = mocker.patch(DIFY_CONFIG_PATH)
    mock_config.EDITION = "CLOUD"
    mock_config.FILES_URL = FILES_BASE_URL
    return {
        "db_session": mock_db_session,
        "query_chain": mock_query_chain,
        "get_features": mock_feature,
        "config": mock_config,
    }
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""next_credit_reset_date should be present in CLOUD edition."""
from services.workspace_service import WorkspaceService
# Arrange
mocker.patch(
CREDIT_POOL_SERVICE_PATH,
side_effect=[None, None], # both paid and trial pools absent
)
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["next_credit_reset_date"] == "2025-02-01"
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""trial_credits/trial_credits_used come from the paid pool when conditions are met."""
from services.workspace_service import WorkspaceService
# Arrange
paid_pool = _make_pool(quota_limit=1000, quota_used=200)
mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool)
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == 1000
assert result["trial_credits_used"] == 200
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""quota_limit == -1 means unlimited; service should still use the paid pool."""
from services.workspace_service import WorkspaceService
# Arrange
paid_pool = _make_pool(quota_limit=-1, quota_used=999)
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == -1
assert result["trial_credits_used"] == 999
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
mocker: MockerFixture,
cloud_mocks: dict,
) -> None:
"""When paid pool is exhausted (used >= limit), switch to trial pool."""
from services.workspace_service import WorkspaceService
# Arrange
paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full
trial_pool = _make_pool(quota_limit=100, quota_used=10)
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
tenant = _make_tenant()
# Act
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
# Assert
assert result is not None
assert result["trial_credits"] == 100
assert result["trial_credits_used"] == 10
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """A missing paid pool makes the service fall back to the trial pool."""
    from services.workspace_service import WorkspaceService

    # Arrange: first lookup (paid) yields None, second (trial) yields a pool.
    mocker.patch(
        CREDIT_POOL_SERVICE_PATH,
        side_effect=[None, _make_pool(quota_limit=50, quota_used=5)],
    )

    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))

    # Assert
    assert info is not None
    assert info["trial_credits"] == 50
    assert info["trial_credits_used"] == 5
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """A SANDBOX subscription skips the paid-pool branch entirely and reports the trial pool."""
    from enums.cloud_plan import CloudPlan
    from services.workspace_service import WorkspaceService

    # Arrange: force the billing plan to SANDBOX before preparing both pools.
    cloud_mocks["get_features"].return_value = _make_feature(
        next_credit_reset_date="2025-02-01",
        billing_plan=CloudPlan.SANDBOX,
    )
    mocker.patch(
        CREDIT_POOL_SERVICE_PATH,
        side_effect=[
            _make_pool(quota_limit=1000, quota_used=0),
            _make_pool(quota_limit=200, quota_used=20),
        ],
    )

    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))

    # Assert: only the trial pool numbers are surfaced.
    assert info is not None
    assert info["trial_credits"] == 200
    assert info["trial_credits_used"] == 20
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """With neither a paid nor a trial pool, no trial credit fields should be present."""
    from services.workspace_service import WorkspaceService

    # Arrange: both pool lookups come back empty.
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])

    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))

    # Assert
    assert info is not None
    assert "trial_credits" not in info
    assert "trial_credits_used" not in info
# ---------------------------------------------------------------------------
# 5. Self-hosted / Non-Cloud Edition
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """SELF_HOSTED mode must expose neither next_credit_reset_date nor trial credit fields."""
    from services.workspace_service import WorkspaceService

    # Arrange: basic_mocks already pins EDITION to "SELF_HOSTED".
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))

    # Assert: none of the cloud-only keys leak into the payload.
    assert info is not None
    for cloud_only_key in ("next_credit_reset_date", "trial_credits", "trial_credits_used"):
        assert cloud_only_key not in info
# ---------------------------------------------------------------------------
# 6. DB query integrity
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """The TenantAccountJoin lookup must be scoped to the tenant id and the current user id."""
    from services.workspace_service import WorkspaceService

    # Arrange: a distinctive user id so the query scope is observable.
    current_user = mocker.patch(CURRENT_USER_PATH)
    current_user.id = "special-user-id"

    # Act
    WorkspaceService.get_tenant_info(_make_tenant(tenant_id="my-special-tenant"))

    # Assert — db.session.query was invoked (at least once).
    basic_mocks["db_session"].query.assert_called()

View File

@ -0,0 +1,643 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.tools.entities.tool_entities import ApiProviderSchemaType
from services.tools.api_tools_manage_service import ApiToolManageService
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
    """Patch the service module's db handle and attach a fresh mock session to it."""
    patched_db = mocker.patch("services.tools.api_tools_manage_service.db")
    patched_db.session = MagicMock()
    return patched_db
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
return SimpleNamespace(operation_id=operation_id)
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
    """A parseable schema yields the schema type, three credential fields, and a warning entry."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
    )

    # Act
    payload = ApiToolManageService.parser_api_schema("valid-schema")

    # Assert
    assert payload["schema_type"] == ApiProviderSchemaType.OPENAPI.value
    assert len(payload["credentials_schema"]) == 3
    assert "warning" in payload
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
    """Parser failures are wrapped into a ValueError carrying the original message."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        side_effect=RuntimeError("bad schema"),
    )

    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
        ApiToolManageService.parser_api_schema("invalid")
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
    """The parser's (bundles, schema_type) tuple is passed through unchanged."""
    # Arrange
    parsed = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=parsed,
    )

    # Act + Assert
    assert ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info={}) == parsed
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
    """Parser errors surface as ValueError('invalid schema: ...')."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        side_effect=ValueError("parse failed"),
    )

    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema: parse failed"):
        ApiToolManageService.convert_schema_to_tool_bundles("schema")
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Creating a provider whose (trimmed) name already exists must fail."""
    # Arrange: the existence lookup finds a row.
    mock_db.session.query.return_value.where.return_value.first.return_value = object()
    creation_kwargs = dict(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name=" provider-a ",
        icon={"emoji": "X"},
        credentials={"auth_type": "none"},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=[],
    )

    # Act + Assert
    with pytest.raises(ValueError, match="provider provider-a already exists"):
        ApiToolManageService.create_api_tool_provider(**creation_kwargs)
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """More than 100 parsed apis in one provider must be rejected."""
    # Arrange: no existing provider, but the schema parses into 101 tools.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle(str(i)) for i in range(101)], ApiProviderSchemaType.OPENAPI),
    )
    creation_kwargs = dict(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-a",
        icon={"emoji": "X"},
        credentials={"auth_type": "none"},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=[],
    )

    # Act + Assert
    with pytest.raises(ValueError, match="the number of apis should be less than 100"):
        ApiToolManageService.create_api_tool_provider(**creation_kwargs)
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Credentials without an auth_type must be rejected."""
    # Arrange: creation proceeds to the credential check with an empty credentials dict.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    creation_kwargs = dict(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-a",
        icon={"emoji": "X"},
        credentials={},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=[],
    )

    # Act + Assert
    with pytest.raises(ValueError, match="auth_type is required"):
        ApiToolManageService.create_api_tool_provider(**creation_kwargs)
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A valid request persists the provider, loads its bundled tools, and reports success."""
    # Arrange: no name collision, one parsed tool, and fully mocked collaborators.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    controller = MagicMock()
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=controller,
    )
    crypto = MagicMock()
    crypto.encrypt.return_value = {"auth_type": "none"}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(crypto, MagicMock()),
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")

    # Act
    outcome = ApiToolManageService.create_api_tool_provider(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-a",
        icon={"emoji": "X"},
        credentials={"auth_type": "none"},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=["news"],
    )

    # Assert: success payload plus exactly one add/commit on the session.
    assert outcome == {"result": "success"}
    controller.load_bundled_tools.assert_called_once()
    mock_db.session.add.assert_called_once()
    mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
    mocker: MockerFixture,
) -> None:
    """A 200 response returns the fetched text under the 'schema' key."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.get",
        return_value=SimpleNamespace(status_code=200, text="schema-content"),
    )
    mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})

    # Act
    payload = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")

    # Assert
    assert payload == {"schema": "schema-content"}
@pytest.mark.parametrize("status_code", [400, 404, 500])
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
    status_code: int,
    mocker: MockerFixture,
) -> None:
    """Any non-200 fetch raises a ValueError and logs the failure exactly once."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.get",
        return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
    )
    logger_mock = mocker.patch("services.tools.api_tools_manage_service.logger")

    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
        ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
    logger_mock.exception.assert_called_once()
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
    mock_db: MagicMock,
) -> None:
    """Listing tools for an unknown provider must fail."""
    # Arrange: the provider lookup comes back empty.
    mock_db.session.query.return_value.where.return_value.first.return_value = None

    # Act + Assert
    with pytest.raises(ValueError, match="you have not added provider provider-a"):
        ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Every tool on the stored provider is converted into its api-entity form."""
    # Arrange: a provider row with two tools and mocked transform helpers.
    stored_provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
    mock_db.session.query.return_value.where.return_value.first.return_value = stored_provider
    mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
        return_value=MagicMock(),
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
    convert_mock = mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
        side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
    )

    # Act
    tools = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")

    # Assert: one conversion per stored tool, in order.
    assert tools == [{"name": "tool-a"}, {"name": "tool-b"}]
    assert convert_mock.call_count == 2
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
    mock_db: MagicMock,
) -> None:
    """Updating a provider that does not exist must fail."""
    # Arrange: the lookup for the original provider finds nothing.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    update_kwargs = dict(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-a",
        original_provider="provider-a",
        icon={},
        credentials={"auth_type": "none"},
        _schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy=None,
        custom_disclaimer="custom",
        labels=[],
    )

    # Act + Assert
    with pytest.raises(ValueError, match="api provider provider-a does not exists"):
        ApiToolManageService.update_api_tool_provider(**update_kwargs)
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Updating with credentials that lack an auth_type must be rejected."""
    # Arrange: an existing provider, but the new credentials dict is empty.
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(
        credentials={}, name="old"
    )
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    update_kwargs = dict(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-a",
        original_provider="provider-a",
        icon={},
        credentials={},
        _schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy=None,
        custom_disclaimer="custom",
        labels=[],
    )

    # Act + Assert
    with pytest.raises(ValueError, match="auth_type is required"):
        ApiToolManageService.update_api_tool_provider(**update_kwargs)
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Masked ('***') credential values are swapped for the stored plaintext before re-encryption."""
    # Arrange: an existing provider row whose credentials were previously encrypted.
    stored_provider = SimpleNamespace(
        credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
        name="old",
        icon="",
        schema="",
        description="",
        schema_type_str="",
        tools_str="",
        privacy_policy="",
        custom_disclaimer="",
        credentials_str="",
    )
    mock_db.session.query.return_value.where.return_value.first.return_value = stored_provider
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=MagicMock(),
    )
    credential_cache = MagicMock()
    crypto = MagicMock()
    crypto.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
    crypto.mask_plugin_credentials.return_value = {"api_key_value": "***"}
    crypto.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(crypto, credential_cache),
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")

    # Act: send back the masked value the UI would have shown.
    outcome = ApiToolManageService.update_api_tool_provider(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-new",
        original_provider="provider-old",
        icon={"emoji": "E"},
        credentials={"auth_type": "none", "api_key_value": "***"},
        _schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=["news"],
    )

    # Assert: row mutated in place, credential cache invalidated, change committed.
    assert outcome == {"result": "success"}
    assert stored_provider.name == "provider-new"
    assert stored_provider.privacy_policy == "privacy"
    assert stored_provider.credentials_str != ""
    credential_cache.delete.assert_called_once()
    mock_db.session.commit.assert_called_once()
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
    """Deleting an unknown provider must fail."""
    # Arrange
    mock_db.session.query.return_value.where.return_value.first.return_value = None

    # Act + Assert
    with pytest.raises(ValueError, match="you have not added provider provider-a"):
        ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
    """An existing provider is removed from the session and the deletion committed."""
    # Arrange
    existing_row = object()
    mock_db.session.query.return_value.where.return_value.first.return_value = existing_row

    # Act + Assert
    assert ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") == {"result": "success"}
    mock_db.session.delete.assert_called_once_with(existing_row)
    mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
    """Provider lookups are forwarded to ToolManager.user_get_api_provider."""
    # Arrange
    delegated = mocker.patch(
        "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
        return_value={"provider": "value"},
    )

    # Act + Assert
    assert ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") == {"provider": "value"}
    delegated.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
    """An unrecognized schema type is rejected before any parsing takes place."""
    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema type"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-a",
            credentials={"auth_type": "none"},
            parameters={},
            schema_type="bad-schema-type",  # type: ignore[arg-type]
            schema="schema",
        )
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
    """A parser failure during preview surfaces as an 'invalid schema' ValueError."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        side_effect=RuntimeError("invalid"),
    )

    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-a",
            credentials={"auth_type": "none"},
            parameters={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
        )
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Previewing a tool name absent from the parsed schema must fail."""
    # Arrange: the schema only defines "tool-a".
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")

    # Act + Assert
    with pytest.raises(ValueError, match="invalid tool name tool-b"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-b",
            credentials={"auth_type": "none"},
            parameters={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
        )
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Previewing with credentials that lack an auth_type must be rejected."""
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")

    # Act + Assert
    with pytest.raises(ValueError, match="auth_type is required"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-a",
            credentials={},
            parameters={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
        )
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Credential-validation failures are reported as an error payload rather than raised."""
    # Arrange: a stored provider plus a runtime tool whose validation blows up.
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(
        id="provider-id", credentials={"auth_type": "none"}
    )
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    runtime_tool = MagicMock()
    runtime_tool.fork_tool_runtime.return_value = runtime_tool
    runtime_tool.validate_credentials.side_effect = ValueError("validation failed")
    controller = MagicMock()
    controller.get_tool.return_value = runtime_tool
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=controller,
    )
    crypto = MagicMock()
    crypto.decrypt.return_value = {"auth_type": "none"}
    crypto.mask_plugin_credentials.return_value = {}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(crypto, MagicMock()),
    )

    # Act
    outcome = ApiToolManageService.test_api_tool_preview(
        tenant_id="tenant-1",
        provider_name="provider-a",
        tool_name="tool-a",
        credentials={"auth_type": "none"},
        parameters={},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
    )

    # Assert
    assert outcome == {"error": "validation failed"}
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A successful validation run is wrapped in a result payload."""
    # Arrange: a stored provider plus a runtime tool whose validation succeeds.
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(
        id="provider-id", credentials={"auth_type": "none"}
    )
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    runtime_tool = MagicMock()
    runtime_tool.fork_tool_runtime.return_value = runtime_tool
    runtime_tool.validate_credentials.return_value = {"ok": True}
    controller = MagicMock()
    controller.get_tool.return_value = runtime_tool
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=controller,
    )
    crypto = MagicMock()
    crypto.decrypt.return_value = {"auth_type": "none"}
    crypto.mask_plugin_credentials.return_value = {}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(crypto, MagicMock()),
    )

    # Act
    outcome = ApiToolManageService.test_api_tool_preview(
        tenant_id="tenant-1",
        provider_name="provider-a",
        tool_name="tool-a",
        credentials={"auth_type": "none"},
        parameters={"x": "1"},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
    )

    # Assert
    assert outcome == {"result": {"ok": True}}
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Each stored provider becomes a user provider whose tools are converted api entities."""
    # Arrange: two stored providers with one and two tools respectively.
    mock_db.session.scalars.return_value.all.return_value = [
        SimpleNamespace(name="p1"),
        SimpleNamespace(name="p2"),
    ]
    first_controller, second_controller = MagicMock(), MagicMock()
    first_controller.get_tools.return_value = ["tool-a"]
    second_controller.get_tools.return_value = ["tool-b", "tool-c"]
    first_user_provider = SimpleNamespace(labels=[], tools=[])
    second_user_provider = SimpleNamespace(labels=[], tools=[])
    mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
        side_effect=[first_controller, second_controller],
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
    mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
        side_effect=[first_user_provider, second_user_provider],
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
    convert_mock = mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
        side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
    )

    # Act
    providers = ApiToolManageService.list_api_tools("tenant-1")

    # Assert: both providers returned; converted tools attached in order.
    assert len(providers) == 2
    assert first_user_provider.tools == [{"name": "tool-a"}]
    assert second_user_provider.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
    assert convert_mock.call_count == 3

File diff suppressed because it is too large Load Diff

View File

@ -324,79 +324,66 @@ packages:
resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==}
cpu: [arm]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm-musleabihf@4.59.0':
resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==}
cpu: [arm]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-arm64-gnu@4.59.0':
resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==}
cpu: [arm64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm64-musl@4.59.0':
resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==}
cpu: [arm64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-loong64-gnu@4.59.0':
resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==}
cpu: [loong64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-loong64-musl@4.59.0':
resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==}
cpu: [loong64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-ppc64-gnu@4.59.0':
resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==}
cpu: [ppc64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-ppc64-musl@4.59.0':
resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==}
cpu: [ppc64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-riscv64-gnu@4.59.0':
resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==}
cpu: [riscv64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-riscv64-musl@4.59.0':
resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==}
cpu: [riscv64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-s390x-gnu@4.59.0':
resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==}
cpu: [s390x]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-gnu@4.59.0':
resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==}
cpu: [x64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-musl@4.59.0':
resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==}
cpu: [x64]
os: [linux]
libc: [musl]
'@rollup/rollup-openbsd-x64@4.59.0':
resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==}

View File

@ -6,19 +6,23 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED
NEXT_PUBLIC_BASE_PATH=
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.
# example: http://cloud.dify.ai/console/api
# example: https://cloud.dify.ai/console/api
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app/api
# example: https://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly.
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=
# Dev-only Hono proxy targets.
# The frontend keeps requesting http://localhost:5001 directly,
# the proxy server will forward the request to the target server,
# so that you don't need to run a separate backend server and use online API in development.
HONO_PROXY_HOST=127.0.0.1
HONO_PROXY_PORT=5001
HONO_CONSOLE_API_PROXY_TARGET=
HONO_PUBLIC_API_PROXY_TARGET=
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=
# The API PREFIX for MARKETPLACE
NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1

View File

@ -1,6 +1,6 @@
# Dify Frontend
This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
This is a [Next.js] project, and you can also develop it with [vinext].
## Getting Started
@ -8,8 +8,11 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org)
- [pnpm](https://pnpm.io)
- [Node.js]
- [pnpm]
You can also use [Vite+] with the corresponding `vp` commands.
For example, use `vp install` instead of `pnpm install` and `vp test` instead of `pnpm run test`.
> [!TIP]
> It is recommended to install and enable Corepack to manage package manager versions automatically:
@ -19,7 +22,7 @@ Before starting the web frontend service, please make sure the following environ
> corepack enable
> ```
>
> Learn more: [Corepack](https://github.com/nodejs/corepack#readme)
> Learn more: [Corepack]
First, install the dependencies:
@ -27,31 +30,14 @@ First, install the dependencies:
pnpm install
```
Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements:
Then, configure the environment variables.
Create a file named `.env.local` in the current directory and copy the contents from `.env.example`.
Modify the values of these environment variables according to your requirements:
```bash
cp .env.example .env.local
```
```txt
# For production release, change this to PRODUCTION
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
# The deployment edition, SELF_HOSTED
NEXT_PUBLIC_EDITION=SELF_HOSTED
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.
# example: http://cloud.dify.ai/console/api
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
NEXT_PUBLIC_COOKIE_DOMAIN=
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
# SENTRY
NEXT_PUBLIC_SENTRY_DSN=
```
> [!IMPORTANT]
>
> 1. When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. The frontend and backend must be under the same top-level domain in order to share authentication cookies.
@ -61,11 +47,16 @@ Finally, run the development server:
```bash
pnpm run dev
# or if you are using vinext which provides a better development experience
pnpm run dev:vinext
# (optional) start the dev proxy server so that you can use online API in development
pnpm run dev:proxy
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
Open <http://localhost:3000> with your browser to see the result.
You can start editing the file under folder `app`. The page auto-updates as you edit the file.
You can start editing the file under folder `app`.
The page auto-updates as you edit the file.
## Deploy
@ -91,7 +82,7 @@ pnpm run start --port=3001 --host=0.0.0.0
## Storybook
This project uses [Storybook](https://storybook.js.org/) for UI component development.
This project uses [Storybook] for UI component development.
To start the storybook server, run:
@ -99,19 +90,24 @@ To start the storybook server, run:
pnpm storybook
```
Open [http://localhost:6006](http://localhost:6006) with your browser to see the result.
Open <http://localhost:6006> with your browser to see the result.
## Lint Code
If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting.
Then follow the [Lint Documentation](./docs/lint.md) to lint the code.
Then follow the [Lint Documentation] to lint the code.
## Test
We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing.
We use [Vitest] and [React Testing Library] for Unit Testing.
**📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples.
**📖 Complete Testing Guide**: See [web/docs/test.md](./docs/test.md) for detailed testing specifications, best practices, and examples.
> [!IMPORTANT]
> As we are using Vite+, the `vitest` command is not available.
> Please make sure to run tests with `vp` commands.
> For example, use `npx vp test` instead of `npx vitest`.
Run test:
@ -119,12 +115,17 @@ Run test:
pnpm test
```
> [!NOTE]
> Our test is not fully stable yet, and we are actively working on improving it.
> If you encounter test failures only in CI but not locally, please feel free to ignore them and report the issue to us.
> You can try to re-run the test in CI, and it may pass successfully.
### Example Code
If you are not familiar with writing tests, refer to:
- [classnames.spec.ts](./utils/classnames.spec.ts) - Utility function test example
- [index.spec.tsx](./app/components/base/button/index.spec.tsx) - Component test example
- [classnames.spec.ts] - Utility function test example
- [index.spec.tsx] - Component test example
### Analyze Component Complexity
@ -134,7 +135,7 @@ Before writing tests, use the script to analyze component complexity:
pnpm analyze-component app/components/your-component/index.tsx
```
This will help you determine the testing strategy. See [web/testing/testing.md](./testing/testing.md) for details.
This will help you determine the testing strategy. See [web/testing/testing.md] for details.
## Documentation
@ -142,4 +143,19 @@ Visit <https://docs.dify.ai> to view the full documentation.
## Community
The Dify community can be found on [Discord community](https://discord.gg/5AEfbxcd9k), where you can ask questions, voice ideas, and share your projects.
The Dify community can be found on [Discord community], where you can ask questions, voice ideas, and share your projects.
[Corepack]: https://github.com/nodejs/corepack#readme
[Discord community]: https://discord.gg/5AEfbxcd9k
[Lint Documentation]: ./docs/lint.md
[Next.js]: https://nextjs.org
[Node.js]: https://nodejs.org
[React Testing Library]: https://testing-library.com/docs/react-testing-library/intro
[Storybook]: https://storybook.js.org
[Vite+]: https://viteplus.dev
[Vitest]: https://vitest.dev
[classnames.spec.ts]: ./utils/classnames.spec.ts
[index.spec.tsx]: ./app/components/base/button/index.spec.tsx
[pnpm]: https://pnpm.io
[vinext]: https://github.com/cloudflare/vinext
[web/docs/test.md]: ./docs/test.md

View File

@ -95,7 +95,7 @@ describe('Cloud Plan Payment Flow', () => {
beforeEach(() => {
vi.clearAllMocks()
cleanup()
toast.close()
toast.dismiss()
setupAppContext()
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })

View File

@ -66,7 +66,7 @@ describe('Self-Hosted Plan Flow', () => {
beforeEach(() => {
vi.clearAllMocks()
cleanup()
toast.close()
toast.dismiss()
setupAppContext()
// Mock window.location with minimal getter/setter (Location props are non-enumerable)

View File

@ -11,8 +11,8 @@ import SideBar from '@/app/components/explore/sidebar'
import { MediaType } from '@/hooks/use-breakpoints'
import { AppModeEnum } from '@/types/app'
const { mockToastAdd } = vi.hoisted(() => ({
mockToastAdd: vi.fn(),
const { mockToastSuccess } = vi.hoisted(() => ({
mockToastSuccess: vi.fn(),
}))
let mockMediaType: string = MediaType.pc
@ -53,14 +53,16 @@ vi.mock('@/service/use-explore', () => ({
}),
}))
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
add: mockToastAdd,
close: vi.fn(),
update: vi.fn(),
promise: vi.fn(),
},
}))
vi.mock('@/app/components/base/ui/toast', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/app/components/base/ui/toast')>()
return {
...actual,
toast: {
...actual.toast,
success: mockToastSuccess,
},
}
})
const createInstalledApp = (overrides: Partial<InstalledApp> = {}): InstalledApp => ({
id: overrides.id ?? 'app-1',
@ -105,9 +107,7 @@ describe('Sidebar Lifecycle Flow', () => {
await waitFor(() => {
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true })
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
}))
expect(mockToastSuccess).toHaveBeenCalled()
})
// Step 2: Simulate refetch returning pinned state, then unpin
@ -124,9 +124,7 @@ describe('Sidebar Lifecycle Flow', () => {
await waitFor(() => {
expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false })
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
}))
expect(mockToastSuccess).toHaveBeenCalled()
})
})
@ -150,10 +148,7 @@ describe('Sidebar Lifecycle Flow', () => {
// Step 4: Uninstall API called and success toast shown
await waitFor(() => {
expect(mockUninstall).toHaveBeenCalledWith('app-1')
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
title: 'common.api.remove',
}))
expect(mockToastSuccess).toHaveBeenCalledWith('common.api.remove')
})
})

View File

@ -24,17 +24,11 @@ export default function CheckCode() {
const verify = async () => {
try {
if (!code.trim()) {
toast.add({
type: 'error',
title: t('checkCode.emptyCode', { ns: 'login' }),
})
toast.error(t('checkCode.emptyCode', { ns: 'login' }))
return
}
if (!/\d{6}/.test(code)) {
toast.add({
type: 'error',
title: t('checkCode.invalidCode', { ns: 'login' }),
})
toast.error(t('checkCode.invalidCode', { ns: 'login' }))
return
}
setIsLoading(true)

View File

@ -27,15 +27,12 @@ export default function CheckCode() {
const handleGetEMailVerificationCode = async () => {
try {
if (!email) {
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
toast.error(t('error.emailEmpty', { ns: 'login' }))
return
}
if (!emailRegex.test(email)) {
toast.add({
type: 'error',
title: t('error.emailInValid', { ns: 'login' }),
})
toast.error(t('error.emailInValid', { ns: 'login' }))
return
}
setIsLoading(true)
@ -48,16 +45,10 @@ export default function CheckCode() {
router.push(`/webapp-reset-password/check-code?${params.toString()}`)
}
else if (res.code === 'account_not_found') {
toast.add({
type: 'error',
title: t('error.registrationNotAllowed', { ns: 'login' }),
})
toast.error(t('error.registrationNotAllowed', { ns: 'login' }))
}
else {
toast.add({
type: 'error',
title: res.data,
})
toast.error(res.data)
}
}
catch (error) {

View File

@ -24,10 +24,7 @@ const ChangePasswordForm = () => {
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
const showErrorMessage = useCallback((message: string) => {
toast.add({
type: 'error',
title: message,
})
toast.error(message)
}, [])
const getSignInUrl = () => {

View File

@ -43,24 +43,15 @@ export default function CheckCode() {
try {
const appCode = getAppCodeFromRedirectUrl()
if (!code.trim()) {
toast.add({
type: 'error',
title: t('checkCode.emptyCode', { ns: 'login' }),
})
toast.error(t('checkCode.emptyCode', { ns: 'login' }))
return
}
if (!/\d{6}/.test(code)) {
toast.add({
type: 'error',
title: t('checkCode.invalidCode', { ns: 'login' }),
})
toast.error(t('checkCode.invalidCode', { ns: 'login' }))
return
}
if (!redirectUrl || !appCode) {
toast.add({
type: 'error',
title: t('error.redirectUrlMissing', { ns: 'login' }),
})
toast.error(t('error.redirectUrlMissing', { ns: 'login' }))
return
}
setIsLoading(true)

View File

@ -17,10 +17,7 @@ const ExternalMemberSSOAuth = () => {
const redirectUrl = searchParams.get('redirect_url')
const showErrorToast = (message: string) => {
toast.add({
type: 'error',
title: message,
})
toast.error(message)
}
const getAppCodeFromRedirectUrl = useCallback(() => {

View File

@ -22,15 +22,12 @@ export default function MailAndCodeAuth() {
const handleGetEMailVerificationCode = async () => {
try {
if (!email) {
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
toast.error(t('error.emailEmpty', { ns: 'login' }))
return
}
if (!emailRegex.test(email)) {
toast.add({
type: 'error',
title: t('error.emailInValid', { ns: 'login' }),
})
toast.error(t('error.emailInValid', { ns: 'login' }))
return
}
setIsLoading(true)

View File

@ -46,26 +46,20 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
const appCode = getAppCodeFromRedirectUrl()
const handleEmailPasswordLogin = async () => {
if (!email) {
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
toast.error(t('error.emailEmpty', { ns: 'login' }))
return
}
if (!emailRegex.test(email)) {
toast.add({
type: 'error',
title: t('error.emailInValid', { ns: 'login' }),
})
toast.error(t('error.emailInValid', { ns: 'login' }))
return
}
if (!password?.trim()) {
toast.add({ type: 'error', title: t('error.passwordEmpty', { ns: 'login' }) })
toast.error(t('error.passwordEmpty', { ns: 'login' }))
return
}
if (!redirectUrl || !appCode) {
toast.add({
type: 'error',
title: t('error.redirectUrlMissing', { ns: 'login' }),
})
toast.error(t('error.redirectUrlMissing', { ns: 'login' }))
return
}
try {
@ -94,15 +88,12 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
router.replace(decodeURIComponent(redirectUrl))
}
else {
toast.add({
type: 'error',
title: res.data,
})
toast.error(res.data)
}
}
catch (e: any) {
if (e.code === 'authentication_failed')
toast.add({ type: 'error', title: e.message })
toast.error(e.message)
}
finally {
setIsLoading(false)

View File

@ -37,10 +37,7 @@ const SSOAuth: FC<SSOAuthProps> = ({
const handleSSOLogin = () => {
const appCode = getAppCodeFromRedirectUrl()
if (!redirectUrl || !appCode) {
toast.add({
type: 'error',
title: t('error.invalidRedirectUrlOrAppCode', { ns: 'login' }),
})
toast.error(t('error.invalidRedirectUrlOrAppCode', { ns: 'login' }))
return
}
setIsLoading(true)
@ -66,10 +63,7 @@ const SSOAuth: FC<SSOAuthProps> = ({
})
}
else {
toast.add({
type: 'error',
title: t('error.invalidSSOProtocol', { ns: 'login' }),
})
toast.error(t('error.invalidSSOProtocol', { ns: 'login' }))
setIsLoading(false)
}
}

View File

@ -91,10 +91,7 @@ export default function OAuthAuthorize() {
globalThis.location.href = url.toString()
}
catch (err: any) {
toast.add({
type: 'error',
title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
})
toast.error(`${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`)
}
}
@ -102,11 +99,10 @@ export default function OAuthAuthorize() {
const invalidParams = !client_id || !redirect_uri
if ((invalidParams || isError) && !hasNotifiedRef.current) {
hasNotifiedRef.current = true
toast.add({
type: 'error',
title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
timeout: 0,
})
toast.error(
invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
{ timeout: 0 },
)
}
}, [client_id, redirect_uri, isError])

View File

@ -137,10 +137,7 @@ const Apps = ({
})
setIsShowCreateModal(false)
toast.add({
type: 'success',
title: t('newApp.appCreated', { ns: 'app' }),
})
toast.success(t('newApp.appCreated', { ns: 'app' }))
if (onSuccess)
onSuccess()
if (app.app_id)
@ -149,7 +146,7 @@ const Apps = ({
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push)
}
catch {
toast.add({ type: 'error', title: t('newApp.appCreateFailed', { ns: 'app' }) })
toast.error(t('newApp.appCreateFailed', { ns: 'app' }))
}
}

View File

@ -141,6 +141,145 @@ describe('useChat', () => {
expect(result.current.chatList[0].suggestedQuestions).toEqual(['Ask Bob'])
})
describe('opening statement referential stability', () => {
it('should keep the same item reference across multiple streaming chatTree mutations', () => {
let callbacks: HookCallbacks
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
callbacks = options as HookCallbacks
})
const config = {
opening_statement: 'Welcome!',
suggested_questions: ['Q1', 'Q2'],
}
const { result } = renderHook(() => useChat(config as ChatConfig))
const openerInitial = result.current.chatList[0]
expect(openerInitial.isOpeningStatement).toBe(true)
expect(openerInitial.content).toBe('Welcome!')
act(() => {
result.current.handleSend('url', { query: 'hello' }, {})
})
act(() => {
callbacks.onWorkflowStarted({ workflow_run_id: 'wr-1', task_id: 't-1' })
})
expect(result.current.chatList[0]).toBe(openerInitial)
act(() => {
callbacks.onData('chunk-1 ', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' })
})
expect(result.current.chatList.length).toBeGreaterThan(1)
expect(result.current.chatList[0]).toBe(openerInitial)
act(() => {
callbacks.onData('chunk-2 ', false, { messageId: 'm-1' })
})
expect(result.current.chatList[0]).toBe(openerInitial)
act(() => {
callbacks.onData('chunk-3', false, { messageId: 'm-1' })
callbacks.onMessageEnd({ metadata: { retriever_resources: [] } })
callbacks.onWorkflowFinished({ data: { status: 'succeeded' } })
callbacks.onCompleted()
})
expect(result.current.chatList[0]).toBe(openerInitial)
expect(result.current.chatList.at(-1)!.content).toBe('chunk-1 chunk-2 chunk-3')
})
it('should keep stable reference when getIntroduction identity changes but output is identical', () => {
const config = {
opening_statement: 'Hello {{name}}',
suggested_questions: ['Ask about {{name}}'],
}
const { result, rerender } = renderHook(
({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings),
{ initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } },
)
const openerBefore = result.current.chatList[0]
expect(openerBefore.content).toBe('Hello Alice')
expect(openerBefore.suggestedQuestions).toEqual(['Ask about Alice'])
rerender({ fs: { inputs: { name: 'Alice' }, inputsForm: [] } })
expect(result.current.chatList[0]).toBe(openerBefore)
})
it('should produce a new item when the processed content actually changes', () => {
const config = {
opening_statement: 'Hello {{name}}',
suggested_questions: ['Ask {{name}}'],
}
const { result, rerender } = renderHook(
({ fs }) => useChat(config as ChatConfig, fs as UseChatFormSettings),
{ initialProps: { fs: { inputs: { name: 'Alice' }, inputsForm: [] } } },
)
const before = result.current.chatList[0]
rerender({ fs: { inputs: { name: 'Bob' }, inputsForm: [] } })
const after = result.current.chatList[0]
expect(after).not.toBe(before)
expect(after.content).toBe('Hello Bob')
expect(after.suggestedQuestions).toEqual(['Ask Bob'])
})
it('should keep content and suggestedQuestions stable for opener already in prevChatTree even when sibling metadata changes', () => {
let callbacks: HookCallbacks
vi.mocked(ssePost).mockImplementation(async (_url, _params, options) => {
callbacks = options as HookCallbacks
})
const config = {
opening_statement: 'Hello updated',
suggested_questions: ['S1'],
}
const prevChatTree = [{
id: 'opening-statement',
content: 'old',
isAnswer: true,
isOpeningStatement: true,
suggestedQuestions: [],
}]
const { result } = renderHook(() =>
useChat(config as ChatConfig, undefined, prevChatTree as ChatItemInTree[]),
)
const openerBefore = result.current.chatList[0]
expect(openerBefore.content).toBe('Hello updated')
expect(openerBefore.suggestedQuestions).toEqual(['S1'])
const contentBefore = openerBefore.content
const suggestionsBefore = openerBefore.suggestedQuestions
act(() => {
result.current.handleSend('url', { query: 'msg' }, {})
})
act(() => {
callbacks.onData('resp', true, { messageId: 'm-1', conversationId: 'c-1', taskId: 't-1' })
})
expect(result.current.chatList.length).toBeGreaterThan(1)
const openerAfter = result.current.chatList[0]
expect(openerAfter.content).toBe(contentBefore)
expect(openerAfter.suggestedQuestions).toBe(suggestionsBefore)
})
it('should use a stable id of "opening-statement"', () => {
const { result } = renderHook(() =>
useChat({ opening_statement: 'Hi' } as ChatConfig),
)
expect(result.current.chatList[0].id).toBe('opening-statement')
})
})
describe('handleSend', () => {
it('should block send if already responding', async () => {
const { result } = renderHook(() => useChat())

Some files were not shown because too many files have changed in this diff Show More