Merge branch 'main' into feat/hitl-frontend

This commit is contained in:
twwu 2026-01-14 13:24:56 +08:00
commit dfb25df5ec
125 changed files with 4121 additions and 1224 deletions

View File

@ -35,7 +35,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@ -64,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@ -86,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email)
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@ -102,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(new_email)
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = new_email
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@ -660,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
email = email.strip()
email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))

View File

@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(

View File

@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@ -100,11 +99,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name

View File

@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args.email
user_email = args.email.lower()
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_email_register_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args.password_confirm)
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(email)
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email, password) -> Account | None:
def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
email=args.email,
email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args.email
user_email = args.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args.code, additional_data={"phase": "reset"}
token_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@ -90,32 +90,38 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
request_email = args.email
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
if args.invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
invite_token = None
try:
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args.email:
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args.email, args.password)
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -130,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args.email,
email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args.email, language=language)
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args.email
original_email = args.email
user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args.email:
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)

View File

@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@ -118,7 +117,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
)
# Set interface language

View File

@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
status: str | None = Field(None, title="Status", description="Document status.")
fetch_val: str = Field("false", alias="fetch")
register_schema_models(
console_ns,
KnowledgeConfig,
@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
limit = param.limit
search = param.search
sort = param.sort_by
status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:

View File

@ -84,10 +84,11 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup
RegisterService.setup(
email=args.email,
email=normalized_email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),

View File

@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
user_email = args.email
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email != current_user.email:
if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
user_email = current_user.email
else:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email(
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
account=account,
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
)
return {"result": "success", "data": token}
@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email
user_email = args.email.lower()
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email)
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
if AccountService.is_account_in_freeze(args.new_email):
if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args.new_email):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
email=args.new_email,
email=normalized_new_email,
)
return updated_account
@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email):
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args.email):
if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from models.account import Account
from services.account_service import AccountService
@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
request_email = payload.email
normalized_email = request_email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = payload.email
user_email = payload.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError()
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(payload.email)
AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=payload.code, additional_data={"phase": "reset"}
token_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
)
args = parser.parse_args()
user_email = args["email"]
user_email = args["email"].lower()
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response

View File

@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)

View File

@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""

View File

@ -189,7 +189,7 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, dict):
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@ -1,6 +1,6 @@
import logging
from core.variables import Variable
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,

View File

@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_calls:
return False
return True
return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):

View File

@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)

View File

@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,

View File

@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")

View File

@ -7,8 +7,8 @@ from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
@ -230,30 +229,32 @@ class WorkflowTool(Tool):
"""
Resolve user from database (worker/Celery context).
"""
with session_factory.create_session() as session:
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = session.scalar(user_stmt)
if user:
user.current_tenant = tenant
session.expunge(user)
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = session.scalar(end_user_stmt)
if end_user:
session.expunge(end_user)
return end_user
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@ -265,22 +266,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
if not workflow:
raise ValueError("workflow not found or not published")
return workflow
session.expunge(workflow)
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
return app
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""

View File

@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
__all__ = [
@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
"VariableBase",
]

View File

@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]

View File

@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
class Variable(Segment):
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable):
class StringVariable(StringSegment, VariableBase):
pass
class FloatVariable(FloatSegment, Variable):
class FloatVariable(FloatSegment, VariableBase):
pass
class IntegerVariable(IntegerSegment, Variable):
class IntegerVariable(IntegerSegment, VariableBase):
pass
class ObjectVariable(ObjectSegment, Variable):
class ObjectVariable(ObjectSegment, VariableBase):
pass
class ArrayVariable(ArraySegment, Variable):
class ArrayVariable(ArraySegment, VariableBase):
pass
@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
class FileVariable(FileSegment, VariableBase):
pass
class BooleanVariable(BooleanSegment, Variable):
class BooleanVariable(BooleanSegment, VariableBase):
pass
@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
# - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `VariableBase`.
Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]

View File

@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import Variable
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
:param variable: The `VariableBase` instance containing the updated value.
"""
pass

View File

@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
class CommandType(StrEnum):
@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):

View File

@ -11,7 +11,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
dict[str, VariableUnion],
dict[str, Variable],
LLMUsage,
]
],
@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
variable: Variable,
variable: VariableBase,
operation: Operation,
value: Any,
):

View File

@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
conversation_variables: Sequence[VariableUnion] = Field(
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:

View File

@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
:return: a list of Variable objects that match the provided selectors.
:return: a list of VariableBase objects that match the provided selectors.
"""
pass
@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []

View File

@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -136,13 +137,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = workflow.get_node_config_by_id(node_id)
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
# Get node class
# Get node type
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@ -158,12 +157,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
# variable selector to variable mapping

View File

@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
# Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
sls_vars_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
if not sls_vars_set:
return False
return all_set
# Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
if not uses_logstore:
return False
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
logger.info("Initializing Aliyun SLS Logstore...")
# Create logstore client and initialize project/logstores/indexes
# Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
logger.exception(
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
"Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
import socket
import threading
import time
from collections.abc import Sequence
@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.
Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds
Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception: # noqa: S110
pass # Ignore errors during cleanup
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@ -220,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
logger.info("Using SDK mode for project %s", self.project_name)
logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
query,
sql,
full_query,
)
try:
@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",

View File

@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from sqlalchemy import create_engine
from configs import dify_config
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
"""Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
# Pool configuration
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
# Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
logger.debug("Using SDK mode for host=%s", pg_host)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
# Build connection URL
from urllib.parse import quote_plus
username = quote_plus(self._access_key_id)
password = quote_plus(self._access_key_secret)
database_url = (
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
# Create SQLAlchemy engine with connection pool
self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
pool_size,
pool_recycle,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
if self._engine:
try:
self._pg_pool.closeall()
self._engine.dispose()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.debug("Failed to dispose engine during cleanup, ignoring")
self._engine = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
connection = self._engine.raw_connection()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
connection.autocommit = True # SLS PG protocol does not support transactions
yield connection
except Exception:
raise
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
connection.close()
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
"""Dispose SQLAlchemy engine and close all connections."""
if self._engine:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
self._engine.dispose()
logger.info("SQLAlchemy engine disposed")
except Exception:
logger.exception("Failed to close PG connection pool")
logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
"""Check if error is retriable (connection-related issues)."""
# Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
"operational error",
"interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""
from typing import Any
def safe_float(value: Any, default: float = 0.0) -> float:
"""
Safely convert a value to float, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value: Any, default: int = 0) -> int:
"""
Safely convert a value to int, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default

View File

@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
model.index = safe_int(data.get("index", 0))
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else:
query = f"id: {execution_id}"
query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now

View File

@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
"""
import logging
@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_steps = safe_int(data.get("total_steps", 0))
model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
# Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build triggered_from filter with escaped values
# Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)
# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
# Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0
to_time = int(time.time()) # now
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
# Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter
time_filter = ""
if time_range:
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
escaped_status = escape_sql_string(status)
if status == "running":
# Running status requires window function
sql = f"""
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),

View File

@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
For LogStore implementation, this is a no-op for the LogStore write because save()
already writes all fields including inputs, process_data, and outputs. The caller
typically calls save() first to persist status/metadata, then calls save_execution_data()
to persist data fields. Since LogStore writes complete records atomically, we don't
need a separate write here to avoid duplicate records.
However, if dual-write is enabled, we still need to call the SQL repository's
save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
logger.debug(
"save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
# Calling save() again would create a duplicate record in the append-only LogStore
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save_execution_data(execution)
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Uses LogStore SQL query with window function to get the latest version of each node execution.
This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
This method uses ROW_NUMBER() window function partitioned by node_execution_id
to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for common case where we only
# want the final state of each node execution
# Build SQL query with deduplication using window function
# ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
# ensures we get the latest version of each node execution
# Build ORDER BY clause
# Escape parameters to prevent SQL injection
escaped_workflow_run_id = escape_identifier(workflow_run_id)
escaped_tenant_id = escape_identifier(self._tenant_id)
# Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
# Build app_id filter for subquery
app_id_filter = ""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
escaped_app_id = escape_identifier(self._app_id)
app_id_filter = f" AND app_id='{escaped_app_id}'"
# Use window function to get latest version of each node execution
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{escaped_workflow_run_id}'
AND tenant_id='{escaped_tenant_id}'
{app_id_filter}
) t
WHERE rn = 1
"""
if order_clause:
sql += f" {order_clause}"

View File

@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries
This module provides escaping utilities to prevent injection attacks in LogStore queries.
LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""
def escape_sql_string(value: str) -> str:
"""
Escape a string value for safe use in SQL queries.
This function escapes single quotes by doubling them, which is the standard
SQL escaping method. This prevents SQL injection by ensuring that user input
cannot break out of string literals.
Args:
value: The string value to escape
Returns:
Escaped string safe for use in SQL queries
Examples:
>>> escape_sql_string("normal_value")
"normal_value"
>>> escape_sql_string("value' OR '1'='1")
"value'' OR ''1''=''1"
>>> escape_sql_string("tenant's_id")
"tenant''s_id"
Security:
- Prevents breaking out of string literals
- Stops injection attacks like: ' OR '1'='1
- Protects against cross-tenant data access
"""
if not value:
return value
# Escape single quotes by doubling them (standard SQL escaping)
# This prevents breaking out of string literals in SQL queries
return value.replace("'", "''")
def escape_identifier(value: str) -> str:
"""
Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
This function is for PG protocol mode (SQL syntax).
For SDK mode, use escape_logstore_query_value() instead.
Args:
value: The identifier value to escape
Returns:
Escaped identifier safe for use in SQL queries
Examples:
>>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
"550e8400-e29b-41d4-a716-446655440000"
>>> escape_identifier("tenant_id' OR '1'='1")
"tenant_id'' OR ''1''=''1"
Security:
- Prevents SQL injection via identifiers
- Stops cross-tenant access attempts
- Works for UUIDs, alphanumeric IDs, and similar identifiers
"""
# For identifiers, use the same escaping as strings
# This is simple and effective for preventing injection
return escape_sql_string(value)
def escape_logstore_query_value(value: str) -> str:
"""
Escape value for LogStore query syntax (SDK mode).
LogStore query syntax rules:
1. Keywords (and/or/not) are case-insensitive
2. Single quotes are ordinary characters (no special meaning)
3. Double quotes wrap values: key:"value"
4. Backslash is the escape character:
- \" for double quote inside value
- \\ for backslash itself
5. Parentheses can change query structure
To prevent injection:
- Wrap value in double quotes to treat special chars as literals
- Escape backslashes and double quotes using backslash
Args:
value: The value to escape for LogStore query syntax
Returns:
Quoted and escaped value safe for LogStore query syntax (includes the quotes)
Examples:
>>> escape_logstore_query_value("normal_value")
'"normal_value"'
>>> escape_logstore_query_value("value or field:evil")
'"value or field:evil"' # 'or' and ':' are now literals
>>> escape_logstore_query_value('value"test')
'"value\\"test"' # Internal double quote escaped
>>> escape_logstore_query_value('value\\test')
'"value\\\\test"' # Backslash escaped
Security:
- Prevents injection via and/or/not keywords
- Prevents injection via colons (:)
- Prevents injection via parentheses
- Protects against cross-tenant data access
Note:
Escape order is critical: backslash first, then double quotes.
Otherwise, we'd double-escape the escape character itself.
"""
if not value:
return '""'
# IMPORTANT: Escape backslashes FIRST, then double quotes
# This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
escaped = value.replace("\\", "\\\\") # \ -> \\
escaped = escaped.replace('"', '\\"') # " -> \"
# Wrap in double quotes to treat as literal string
# This prevents and/or/not/:/() from being interpreted as operators
return f'"{escaped}"'

View File

@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
result: Variable
result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
) -> VariableBase:
if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
Variable,
VariableBase,
variable_class(
id=id,
name=name,

View File

@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
if isinstance(value, Variable):
if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> WorkflowType:
def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
var: Variable,
var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var: Variable) -> Variable:
def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]):
def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship(
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@ -692,7 +690,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
file: Mapped[UploadFile | None] = orm.relationship(
file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
)
return obj
def to_variable(self) -> Variable:
def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship(
upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship(
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.2"
version = "1.11.3"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@ -1363,16 +1378,22 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
normalized_email = email.lower()
"""Invite new member"""
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
name = email.split("@")[0]
name = normalized_email.split("@")[0]
account = cls.register(
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
email=normalized_email,
name=name,
language=language,
status=AccountStatus.PENDING,
is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@ -1394,7 +1415,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
to=email,
to=account.email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@ -1493,6 +1514,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)

View File

@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from models import ConversationVariable
@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)

View File

@ -1,9 +1,14 @@
import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
from core.helper.trace_id_helper import generate_traceparent_header
logger = logging.getLogger(__name__)
class BaseRequest:
proxies: Mapping[str, str] | None = {
@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
try:
# ensure traceparent even when OTEL is disabled
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from sqlalchemy import select
from yarl import URL
from configs import dify_config
@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
# Get plugin info before uninstalling to delete associated credentials
try:
plugins = manager.list_plugins(tenant_id)
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if plugin:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
# Delete provider credentials that match this plugin
credentials = db.session.scalars(
select(ProviderCredential).where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.like(f"{plugin_id}/%"),
)
).all()
for cred in credentials:
db.session.delete(cred)
db.session.commit()
logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
except Exception as e:
logger.warning("Failed to delete credentials: %s", e)
# Continue with uninstall even if credential deletion fails
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod

View File

@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""

View File

@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password
from models import Account, AccountStatus
from models.model import App, EndUser, Site
from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
raise AccountNotFoundError()
@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).where(Account.email == email).first()
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
return None

View File

@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable]
_fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {}
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.

View File

@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.variables import VariableBase
from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool

View File

@ -40,7 +40,7 @@ class TestActivateCheckApi:
"tenant": tenant,
}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
"""
Test checking valid invitation token.
@ -66,7 +66,7 @@ class TestActivateCheckApi:
assert response["data"]["workspace_id"] == "workspace-123"
assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
"""
Test checking invalid invitation token.
@ -88,7 +88,7 @@ class TestActivateCheckApi:
# Assert
assert response["is_valid"] is False
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
"""
Test checking token without workspace ID.
@ -109,7 +109,7 @@ class TestActivateCheckApi:
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
"""
Test checking token without email parameter.
@ -130,6 +130,20 @@ class TestActivateCheckApi:
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
"""Ensure token validation uses lowercase emails."""
mock_get_invitation.return_value = mock_invitation
with app.test_request_context(
"/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token"
):
api = ActivateCheckApi()
response = api.get()
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
class TestActivateApi:
"""Test cases for account activation endpoint."""
@ -212,7 +226,7 @@ class TestActivateApi:
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_activation_with_invalid_token(self, mock_get_invitation, app):
"""
Test account activation with invalid token.
@ -241,7 +255,7 @@ class TestActivateApi:
with pytest.raises(AlreadyActivateError):
api.post()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
def test_activation_sets_interface_theme(
@ -290,7 +304,7 @@ class TestActivateApi:
("es-ES", "Europe/Madrid"),
],
)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
def test_activation_with_different_locales(
@ -336,7 +350,7 @@ class TestActivateApi:
assert mock_account.interface_language == language
assert mock_account.timezone == timezone
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
def test_activation_returns_success_response(
@ -376,7 +390,7 @@ class TestActivateApi:
# Assert
assert response == {"result": "success"}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
def test_activation_without_workspace_id(
@ -415,3 +429,37 @@ class TestActivateApi:
# Assert
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
def test_activation_normalizes_email_before_lookup(
self,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
):
"""Ensure uppercase emails are normalized before lookup and revocation."""
mock_get_invitation.return_value = mock_invitation
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "Invitee@Example.com",
"token": "valid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
response = api.post()
assert response["result"] == "success"
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")

View File

@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invalid_email_with_registration_allowed(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_wrong_password_returns_error(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
):
@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invalid_email_with_registration_disabled(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):

View File

@ -0,0 +1,177 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
EmailRegisterResetApi,
EmailRegisterSendEmailApi,
)
from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
class TestEmailRegisterSendEmailApi:
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_send_email_normalizes_and_falls_back(
self,
mock_extract_ip,
mock_is_email_send_ip_limit,
mock_is_freeze,
mock_send_mail,
mock_get_account,
mock_session_cls,
app,
):
mock_send_mail.return_value = "token-123"
mock_is_freeze.return_value = False
mock_account = MagicMock()
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = mock_account
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register/send-email",
method="POST",
json={"email": "Invitee@Example.com", "language": "en-US"},
):
response = EmailRegisterSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_is_freeze.assert_called_once_with("invitee@example.com")
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
mock_extract_ip.assert_called_once()
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
class TestEmailRegisterCheckApi:
@patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
def test_validity_normalizes_email_before_checks(
self,
mock_rate_limit_check,
mock_get_data,
mock_add_rate,
mock_revoke,
mock_generate_token,
mock_reset_rate,
app,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
mock_generate_token.return_value = (None, "new-token")
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register/validity",
method="POST",
json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
):
response = EmailRegisterCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_rate_limit_check.assert_called_once_with("user@example.com")
mock_generate_token.assert_called_once_with(
"user@example.com", code="4321", additional_data={"phase": "register"}
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_add_rate.assert_not_called()
mock_revoke.assert_called_once_with("token-123")
class TestEmailRegisterResetApi:
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_reset_creates_account_with_normalized_email(
self,
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_create_account,
mock_login,
mock_reset_login_rate,
app,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register",
method="POST",
json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
):
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert account is expected_account
assert mock_session.execute.call_count == 2

View File

@ -0,0 +1,176 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
ForgotPasswordResetApi,
ForgotPasswordSendEmailApi,
)
from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
class TestForgotPasswordSendEmailApi:
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
def test_send_normalizes_email(
self,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_account,
mock_session_cls,
app,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
controller_features = SimpleNamespace(is_allow_register=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch(
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
return_value=controller_features,
),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
with app.test_request_context(
"/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=mock_account,
email="user@example.com",
language="zh-Hans",
is_allow_register=True,
)
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_extract_ip.assert_called_once()
class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_check_normalizes_email(
self,
mock_rate_limit_check,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
mock_generate_token.return_value = (None, "new-token")
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
mock_rate_limit_check.assert_called_once_with("admin@example.com")
mock_generate_token.assert_called_once_with(
"Admin@Example.com",
code="4321",
additional_data={"phase": "reset"},
)
mock_reset_rate.assert_called_once_with("admin@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_fetches_account_with_original_email(
self,
mock_get_reset_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_update_account,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("token-123")
mock_revoke_token.assert_called_once_with("token-123")
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
assert account is expected_account
assert mock_session.execute.call_count == 2

View File

@ -76,7 +76,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@ -120,7 +120,7 @@ class TestLoginApi:
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
mock_login.assert_called_once()
mock_reset_rate_limit.assert_called_once_with("test@example.com")
assert response.json["result"] == "success"
@ -128,7 +128,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@ -182,7 +182,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
"""
Test login rejection when rate limit is exceeded.
@ -230,7 +230,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
def test_login_fails_with_invalid_credentials(
@ -269,7 +269,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
def test_login_fails_for_banned_account(
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
@ -298,7 +298,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@ -343,7 +343,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
"""
Test login failure when invitation email doesn't match login email.
@ -371,6 +371,52 @@ class TestLoginApi:
with pytest.raises(InvalidEmailError):
login_api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_login_retries_with_lowercase_email(
self,
mock_reset_rate_limit,
mock_login_service,
mock_get_tenants,
mock_add_rate_limit,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
mock_account,
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
mock_get_tenants.return_value = [MagicMock()]
mock_login_service.return_value = mock_token_pair
with app.test_request_context(
"/login",
method="POST",
json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")},
):
response = LoginApi().post()
assert response.json["result"] == "success"
assert mock_authenticate.call_args_list == [
(("Upper@Example.com", "ValidPass123!", None), {}),
(("upper@example.com", "ValidPass123!", None), {}),
]
mock_add_rate_limit.assert_not_called()
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
class TestLogoutApi:
"""Test cases for the LogoutApi endpoint."""

View File

@ -12,6 +12,7 @@ from controllers.console.auth.oauth import (
)
from libs.oauth import OAuthUserInfo
from models.account import AccountStatus
from services.account_service import AccountService
from services.errors.account import AccountRegisterError
@ -215,6 +216,34 @@ class TestOAuthCallback:
assert status_code == 400
assert response["error"] == expected_error
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.redirect")
def test_invitation_comparison_is_case_insensitive(
self,
mock_redirect,
mock_register_service,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
id="123", name="Test User", email="User@Example.com"
)
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_register_service.is_valid_invite_token.return_value = True
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
resource.get("github")
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
@pytest.mark.parametrize(
("account_status", "expected_redirect"),
[
@ -395,12 +424,12 @@ class TestAccountGeneration:
account.name = "Test User"
return account
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.select")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.db")
def test_should_get_account_by_openid_or_email(
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
):
# Mock db.engine for Session creation
mock_db.engine = MagicMock()
@ -410,15 +439,31 @@ class TestAccountGeneration:
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
mock_get_account.assert_not_called()
# Test fallback to email
# Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert result == expected_account
assert mock_session.execute.call_count == 2
@pytest.mark.parametrize(
("allow_register", "existing_account", "should_create"),
@ -466,6 +511,35 @@ class TestAccountGeneration:
mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
)
else:
mock_register_service.register.assert_not_called()
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "en-US"}):
_generate_account("github", user_info)
mock_register_service.register.assert_called_once_with(
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.TenantService")

View File

@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import (
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
@pytest.fixture(autouse=True)
def _mock_forgot_password_session():
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.__exit__.return_value = None
yield mock_session
@pytest.fixture(autouse=True)
def _mock_forgot_password_db():
with patch("controllers.console.auth.forgot_password.db") as mock_db:
mock_db.engine = MagicMock()
yield mock_db
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi:
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
def test_send_reset_email_success(
self,
mock_get_features,
mock_send_email,
mock_select,
mock_session,
mock_get_account,
mock_is_ip_limit,
mock_forgot_db,
mock_wraps_db,
app,
mock_account,
@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi:
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_is_ip_limit.return_value = False
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
mock_send_email.return_value = "reset_token_123"
mock_get_features.return_value.is_allow_register = True
@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi:
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
def test_send_reset_email_language_handling(
self,
mock_get_features,
mock_send_email,
mock_select,
mock_session,
mock_get_account,
mock_is_ip_limit,
mock_forgot_db,
mock_wraps_db,
app,
mock_account,
@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi:
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_is_ip_limit.return_value = False
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token"
mock_get_features.return_value.is_allow_register = True
@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi:
assert response["email"] == "test@example.com"
assert response["token"] == "new_token"
mock_revoke_token.assert_called_once_with("old_token")
mock_generate_token.assert_called_once_with(
"test@example.com", code="123456", additional_data={"phase": "reset"}
)
mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
def test_verify_code_preserves_token_email_case(
self,
mock_reset_rate_limit,
mock_generate_token,
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
mock_generate_token.return_value = (None, "fresh-token")
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "user@example.com", "code": "999888", "token": "upper_token"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"}
mock_generate_token.assert_called_once_with(
"User@Example.com", code="999888", additional_data={"phase": "reset"}
)
mock_revoke_token.assert_called_once_with("upper_token")
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
@ -355,20 +395,16 @@ class TestForgotPasswordResetApi:
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
def test_reset_password_success(
self,
mock_get_tenants,
mock_select,
mock_session,
mock_get_account,
mock_revoke_token,
mock_get_data,
mock_forgot_db,
mock_wraps_db,
app,
mock_account,
@ -383,11 +419,8 @@ class TestForgotPasswordResetApi:
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
# Act
@ -475,13 +508,11 @@ class TestForgotPasswordResetApi:
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
):
"""
Test password reset for non-existent account.
@ -491,11 +522,8 @@ class TestForgotPasswordResetApi:
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = None
# Act & Assert
with app.test_request_context(

View File

@ -0,0 +1,39 @@
from types import SimpleNamespace
from unittest.mock import patch
from controllers.console.setup import SetupApi
class TestSetupApi:
def test_post_lowercases_email_before_register(self):
"""Ensure setup registration normalizes email casing."""
payload = {
"email": "Admin@Example.com",
"name": "Admin User",
"password": "ValidPass123!",
"language": "en-US",
}
setup_api = SetupApi(api=None)
mock_console_ns = SimpleNamespace(payload=payload)
with (
patch("controllers.console.setup.console_ns", mock_console_ns),
patch("controllers.console.setup.get_setup_status", return_value=False),
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
patch("controllers.console.setup.get_init_validate_status", return_value=True),
patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"),
patch("controllers.console.setup.request", object()),
patch("controllers.console.setup.RegisterService.setup") as mock_register,
):
response, status = setup_api.post()
assert response == {"result": "success"}
assert status == 201
mock_register.assert_called_once_with(
email="admin@example.com",
name=payload["name"],
password=payload["password"],
ip_address="127.0.0.1",
language=payload["language"],
)

View File

@ -0,0 +1,247 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.workspace.account import (
AccountDeleteUpdateFeedbackApi,
ChangeEmailCheckApi,
ChangeEmailResetApi,
ChangeEmailSendEmailApi,
CheckEmailUnique,
)
from models import Account
from services.account_service import AccountService
@pytest.fixture
def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(_load_user=lambda: None)
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
account.email = email
account.id = account_id
account.status = "active"
account._current_tenant = tenant_obj
return account
def _set_logged_in_user(account: Account):
g._login_user = account
g._current_tenant = account.current_tenant
class TestChangeEmailSend:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_phase(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {"email": "current@example.com"}
mock_send_email.return_value = "token-abc"
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
assert response == {"result": "success", "data": "token-abc"}
mock_send_email.assert_called_once_with(
account=None,
email="new@example.com",
old_email="current@example.com",
language="en-US",
phase="new_email",
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_validate_with_normalized_email(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limit.assert_called_once_with("user@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_before_update(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {"old_email": "OLD@example.com"}
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "New@Example.com", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
ChangeEmailResetApi().post()
mock_is_freeze.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
json={"email": "User@Example.com", "feedback": "test"},
):
response = AccountDeleteUpdateFeedbackApi().post()
assert response == {"result": "success"}
mock_update.assert_called_once_with("User@Example.com", "test")
class TestCheckEmailUnique:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
with app.test_request_context(
"/account/change-email/check-email-unique",
method="POST",
json={"email": "Case@Test.com"},
):
response = CheckEmailUnique().post()
assert response == {"result": "success"}
mock_is_freeze.assert_called_once_with("case@test.com")
mock_check_unique.assert_called_once_with("case@test.com")
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
session = MagicMock()
first = MagicMock()
first.scalar_one_or_none.return_value = None
second = MagicMock()
expected_account = MagicMock()
second.scalar_one_or_none.return_value = expected_account
session.execute.side_effect = [first, second]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
assert result is expected_account
assert session.execute.call_count == 2

View File

@ -0,0 +1,82 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.workspace.members import MemberInviteEmailApi
from models.account import Account, TenantAccountRole
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
return SimpleNamespace(
billing=SimpleNamespace(enabled=False),
workspace_members=workspace_members,
members=placeholder_quota,
apps=placeholder_quota,
vector_space=placeholder_quota,
documents_upload_quota=placeholder_quota,
annotation_quota_limit=placeholder_quota,
)
class TestMemberInviteEmailApi:
@patch("controllers.console.workspace.members.FeatureService.get_features")
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
@patch("controllers.console.workspace.members.current_account_with_tenant")
@patch("controllers.console.wraps.db")
@patch("libs.login.check_csrf_token", return_value=None)
def test_invite_normalizes_emails(
self,
mock_csrf,
mock_db,
mock_current_account,
mock_invite_member,
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
mock_current_account.return_value = (inviter, tenant.id)
with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"):
with app.test_request_context(
"/workspaces/current/members/invite-email",
method="POST",
json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"},
):
account = Account(name="tester", email="tester@example.com")
account._current_tenant = tenant
g._login_user = account
g._current_tenant = tenant
response, status_code = MemberInviteEmailApi().post()
assert status_code == 201
assert response["invitation_results"][0]["email"] == "user@example.com"
assert mock_invite_member.call_count == 1
call_args = mock_invite_member.call_args
assert call_args.kwargs["tenant"] == tenant
assert call_args.kwargs["email"] == "User@Example.com"
assert call_args.kwargs["language"] == "en-US"
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
assert call_args.kwargs["inviter"] == inviter
mock_csrf.assert_called_once()

View File

@ -1,195 +0,0 @@
"""Unit tests for controllers.web.forgot_password endpoints."""
from __future__ import annotations
import base64
import builtins
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
# Ensure flask_restx.api finds MethodView during import.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_controller_module():
"""Import controllers.web.forgot_password using a stub package."""
import importlib
import importlib.util
import sys
from types import ModuleType
parent_module_name = "controllers.web"
module_name = f"{parent_module_name}.forgot_password"
if parent_module_name not in sys.modules:
from flask_restx import Namespace
stub = ModuleType(parent_module_name)
stub.__file__ = "controllers/web/__init__.py"
stub.__path__ = ["controllers/web"]
stub.__package__ = "controllers"
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
stub.web_ns = Namespace("web", description="Web API", path="/")
sys.modules[parent_module_name] = stub
return importlib.import_module(module_name)
forgot_password_module = _load_controller_module()
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
@pytest.fixture
def app() -> Flask:
"""Configure a minimal Flask app for request contexts."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture(autouse=True)
def _enable_web_endpoint_guards():
"""Stub enterprise and feature toggles used by route decorators."""
features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
):
yield
@pytest.fixture(autouse=True)
def _mock_controller_db():
"""Replace controller-level db reference with a simple stub."""
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
fake_wraps_db = SimpleNamespace(
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
)
with (
patch("controllers.web.forgot_password.db", fake_db),
patch("controllers.console.wraps.db", fake_wraps_db),
):
yield fake_db
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
def test_send_reset_email_success(
mock_extract_ip: MagicMock,
mock_is_ip_limit: MagicMock,
mock_session: MagicMock,
mock_send_email: MagicMock,
app: Flask,
):
"""POST /forgot-password returns token when email exists and limits allow."""
mock_account = MagicMock()
session_ctx = MagicMock()
mock_session.return_value.__enter__.return_value = session_ctx
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
with app.test_request_context(
"/forgot-password",
method="POST",
json={"email": "user@example.com"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "reset-token"}
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
def test_check_token_success(
mock_is_rate_limited: MagicMock,
mock_get_data: MagicMock,
mock_revoke: MagicMock,
mock_generate: MagicMock,
mock_reset_limit: MagicMock,
app: Flask,
):
"""POST /forgot-password/validity validates the code and refreshes token."""
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limited.assert_called_once_with("user@example.com")
mock_get_data.assert_called_once_with("old-token")
mock_revoke.assert_called_once_with("old-token")
mock_generate.assert_called_once_with(
"user@example.com",
code="123456",
additional_data={"phase": "reset"},
)
mock_reset_limit.assert_called_once_with("user@example.com")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_success(
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_session: MagicMock,
mock_token_bytes: MagicMock,
mock_hash_password: MagicMock,
app: Flask,
):
"""POST /forgot-password/resets updates the stored password when token is valid."""
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
account = MagicMock()
session_ctx = MagicMock()
mock_session.return_value.__enter__.return_value = session_ctx
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_data.assert_called_once_with("reset-token")
mock_revoke_token.assert_called_once_with("reset-token")
mock_token_bytes.assert_called_once_with(16)
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
expected_password = base64.b64encode(b"hashed-value").decode()
assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt
session_ctx.commit.assert_called_once()

View File

@ -0,0 +1,226 @@
import base64
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.forgot_password import (
ForgotPasswordCheckApi,
ForgotPasswordResetApi,
ForgotPasswordSendEmailApi,
)
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture(autouse=True)
def _patch_wraps():
wraps_features = SimpleNamespace(enable_email_password_login=True)
dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
with (
patch("controllers.console.wraps.db") as mock_db,
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield
class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.Session")
def test_should_normalize_email_before_sending(
self,
mock_session_cls,
mock_extract_ip,
mock_rate_limit,
mock_get_account,
mock_send_mail,
app,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
mock_extract_ip.assert_called_once()
mock_rate_limit.assert_called_once_with("127.0.0.1")
class TestForgotPasswordCheckApi:
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_should_normalize_email_for_validity_checks(
self,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
"/web/forgot-password/validity",
method="POST",
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limit.assert_called_once_with("user@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"User@Example.com",
code="1234",
additional_data={"phase": "reset"},
)
mock_reset_rate.assert_called_once_with("user@example.com")
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_should_preserve_token_email_case(
self,
mock_is_rate_limit,
mock_get_data,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
mock_generate_token.return_value = (None, "fresh-token")
with app.test_request_context(
"/web/forgot-password/validity",
method="POST",
json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"}
mock_generate_token.assert_called_once_with(
"MixedCase@Example.com",
code="5678",
additional_data={"phase": "reset"},
)
mock_revoke_token.assert_called_once_with("token-upper")
mock_reset_rate.assert_called_once_with("mixedcase@example.com")
class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback(
self,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_get_account,
mock_update_account,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
mock_revoke_token.assert_called_once_with("token-123")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_should_update_password_and_commit(
self,
mock_get_account,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_token_bytes,
mock_hash_password,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock()
mock_get_account.return_value = account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("reset-token")
mock_revoke_token.assert_called_once_with("reset-token")
mock_token_bytes.assert_called_once_with(16)
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
expected_password = base64.b64encode(b"hashed-value").decode()
assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt
mock_session.commit.assert_called_once()

View File

@ -0,0 +1,91 @@
import base64
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
def encode_code(code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode()
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture(autouse=True)
def _patch_wraps():
wraps_features = SimpleNamespace(enable_email_password_login=True)
console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True)
with (
patch("controllers.console.wraps.db") as mock_db,
patch("controllers.console.wraps.dify_config", console_dify),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield
class TestEmailCodeLoginSendEmailApi:
@patch("controllers.web.login.WebAppAuthService.send_email_code_login_email")
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
def test_should_fetch_account_with_original_email(
self,
mock_get_user,
mock_send_email,
app,
):
mock_account = MagicMock()
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token-123"
with app.test_request_context(
"/web/email-code-login",
method="POST",
json={"email": "User@Example.com", "language": "en-US"},
):
response = EmailCodeLoginSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_user.assert_called_once_with("User@Example.com")
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
class TestEmailCodeLoginApi:
@patch("controllers.web.login.AccountService.reset_login_error_rate_limit")
@patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token")
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data")
def test_should_normalize_email_before_validating(
self,
mock_get_token_data,
mock_revoke_token,
mock_get_user,
mock_login,
mock_reset_login_rate,
app,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_user.return_value = MagicMock()
with app.test_request_context(
"/web/email-code-login/validity",
method="POST",
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
):
response = EmailCodeLoginApi().post()
assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}}
mock_get_user.assert_called_once_with("User@Example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_login.assert_called_once()
mock_reset_login_rate.assert_called_once_with("user@example.com")

View File

@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M
def scalar(self, _stmt):
return self.results.pop(0)
# SQLAlchemy Session APIs used by code under test
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
# support `with session_factory.create_session() as session:`
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
tenant = SimpleNamespace(id="tenant_id")
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
# Monkeypatch session factory to return our stub session
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([tenant, None, end_user]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt
def scalar(self, _stmt):
return self.results.pop(0)
db_stub = SimpleNamespace(session=StubSession([None]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
# Monkeypatch session factory to return our stub session with no tenant
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([None]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),

View File

@ -35,7 +35,6 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -96,7 +95,7 @@ class _Segments(BaseModel):
class _Variables(BaseModel):
variables: list[VariableUnion]
variables: list[Variable]
def create_test_file(
@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
all_variables: list[Variable] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),

View File

@ -11,7 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
from core.variables.variables import VariableBase
def test_frozen_variables():
@ -76,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var: Variable = StringVariable(name="text", value="text")
var: VariableBase = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

View File

@ -24,7 +24,7 @@ from core.variables.variables import (
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.runtime import VariablePool
@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
env_vars: list[Variable] = [
StringVariable(
id="env_string_id",
name="env_string",
@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
conv_vars: list[Variable] = [
StringVariable(
id="conv_string_id",
name="conv_string",

View File

@ -2,13 +2,17 @@ from types import SimpleNamespace
import pytest
from configs import dify_config
from core.file.enums import FileType
from core.file.models import File, FileTransferMethod
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.variables import StringVariable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
@ -96,6 +100,58 @@ class TestWorkflowEntry:
assert output_var is not None
assert output_var.value == "system_user"
def test_single_step_run_injects_code_limits(self):
"""Ensure single-step CodeNode execution configures limits."""
# Arrange
node_id = "code_node"
node_data = {
"type": "code",
"title": "Code",
"desc": None,
"variables": [],
"code_language": CodeLanguage.PYTHON3,
"code": "def main():\n return {}",
"outputs": {},
}
node_config = {"id": node_id, "data": node_data}
class StubWorkflow:
def __init__(self):
self.tenant_id = "tenant"
self.app_id = "app"
self.id = "workflow"
self.graph_dict = {"nodes": [node_config], "edges": []}
def get_node_config_by_id(self, target_id: str):
assert target_id == node_id
return node_config
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
# Act
node, _ = WorkflowEntry.single_step_run(
workflow=workflow,
node_id=node_id,
user_id="user",
user_inputs={},
variable_pool=variable_pool,
)
# Assert
assert isinstance(node, CodeNode)
assert node._limits == expected_limits
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
"""Test mapping environment variables from user inputs to variable pool."""
# Initialize variable pool with environment variables

View File

@ -0,0 +1 @@
"""LogStore extension unit tests."""

View File

@ -0,0 +1,469 @@
"""
Unit tests for SQL escape utility functions.
These tests ensure that SQL injection attacks are properly prevented
in LogStore queries, particularly for cross-tenant access scenarios.
"""
import pytest
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
class TestEscapeSQLString:
"""Test escape_sql_string function."""
def test_escape_empty_string(self):
"""Test escaping empty string."""
assert escape_sql_string("") == ""
def test_escape_normal_string(self):
"""Test escaping string without special characters."""
assert escape_sql_string("tenant_abc123") == "tenant_abc123"
assert escape_sql_string("app-uuid-1234") == "app-uuid-1234"
def test_escape_single_quote(self):
"""Test escaping single quote."""
# Single quote should be doubled
assert escape_sql_string("tenant'id") == "tenant''id"
assert escape_sql_string("O'Reilly") == "O''Reilly"
def test_escape_multiple_quotes(self):
"""Test escaping multiple single quotes."""
assert escape_sql_string("a'b'c") == "a''b''c"
assert escape_sql_string("'''") == "''''''"
# === SQL Injection Attack Scenarios ===
def test_prevent_boolean_injection(self):
"""Test prevention of boolean injection attacks."""
# Classic OR 1=1 attack
malicious_input = "tenant' OR '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR ''1''=''1"
# When used in SQL, this becomes a safe string literal
sql = f"WHERE tenant_id='{escaped}'"
assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'"
# The entire input is now a string literal that won't match any tenant
def test_prevent_or_injection(self):
"""Test prevention of OR-based injection."""
malicious_input = "tenant_a' OR tenant_id='tenant_b"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant_a'' OR tenant_id=''tenant_b"
sql = f"WHERE tenant_id='{escaped}'"
# The OR is now part of the string literal, not SQL logic
assert "OR tenant_id=" in sql
# The SQL has: opening ', doubled internal quotes '', and closing '
assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'"
def test_prevent_union_injection(self):
"""Test prevention of UNION-based injection."""
malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1"
# UNION becomes part of the string literal
assert "UNION" in escaped
assert escaped.count("''") == 4 # All internal quotes are doubled
def test_prevent_comment_injection(self):
"""Test prevention of comment-based injection."""
# SQL comment to bypass remaining conditions
malicious_input = "tenant' --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' --"
sql = f"WHERE tenant_id='{escaped}' AND deleted=false"
# The -- is now inside the string, not a SQL comment
assert "--" in sql
assert "AND deleted=false" in sql # This part is NOT commented out
def test_prevent_semicolon_injection(self):
"""Test prevention of semicolon-based multi-statement injection."""
malicious_input = "tenant'; DROP TABLE users; --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant''; DROP TABLE users; --"
# Semicolons and DROP are now part of the string
assert "DROP TABLE" in escaped
def test_prevent_time_based_blind_injection(self):
"""Test prevention of time-based blind SQL injection."""
malicious_input = "tenant' AND SLEEP(5) --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' AND SLEEP(5) --"
# SLEEP becomes part of the string
assert "SLEEP" in escaped
def test_prevent_wildcard_injection(self):
"""Test prevention of wildcard-based injection."""
malicious_input = "tenant' OR tenant_id LIKE '%"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR tenant_id LIKE ''%"
# The LIKE and wildcard are now part of the string
assert "LIKE" in escaped
def test_prevent_null_byte_injection(self):
"""Test handling of null bytes."""
# Null bytes can sometimes bypass filters
malicious_input = "tenant\x00' OR '1'='1"
escaped = escape_sql_string(malicious_input)
# Null byte is preserved, but quote is escaped
assert "''1''=''1" in escaped
# === Real-world SAAS Scenarios ===
def test_cross_tenant_access_attempt(self):
"""Test prevention of cross-tenant data access."""
# Attacker tries to access another tenant's data
attacker_input = "tenant_b' OR tenant_id='tenant_a"
escaped = escape_sql_string(attacker_input)
sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'"
# The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a"
# which doesn't exist - preventing access to either tenant's data
assert "tenant_b'' OR tenant_id=''tenant_a" in sql
def test_cross_app_access_attempt(self):
"""Test prevention of cross-application data access."""
attacker_input = "app1' OR app_id='app2"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE app_id='{escaped}'"
# Cannot access app2's data
assert "app1'' OR app_id=''app2" in sql
def test_bypass_status_filter(self):
"""Test prevention of bypassing status filters."""
# Try to see all statuses instead of just 'running'
attacker_input = "running' OR status LIKE '%"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE status='{escaped}'"
# Status condition is not bypassed
assert "running'' OR status LIKE ''%" in sql
# === Edge Cases ===
def test_escape_only_quotes(self):
"""Test string with only quotes."""
assert escape_sql_string("'") == "''"
assert escape_sql_string("''") == "''''"
def test_escape_mixed_content(self):
"""Test string with mixed quotes and other chars."""
input_str = "It's a 'test' of O'Reilly's code"
escaped = escape_sql_string(input_str)
assert escaped == "It''s a ''test'' of O''Reilly''s code"
def test_escape_unicode_with_quotes(self):
"""Test Unicode strings with quotes."""
input_str = "租户' OR '1'='1"
escaped = escape_sql_string(input_str)
assert escaped == "租户'' OR ''1''=''1"
class TestEscapeIdentifier:
"""Test escape_identifier function."""
def test_escape_uuid(self):
"""Test escaping UUID identifiers."""
uuid = "550e8400-e29b-41d4-a716-446655440000"
assert escape_identifier(uuid) == uuid
def test_escape_alphanumeric_id(self):
"""Test escaping alphanumeric identifiers."""
assert escape_identifier("tenant_123") == "tenant_123"
assert escape_identifier("app-abc-123") == "app-abc-123"
def test_escape_identifier_with_quote(self):
"""Test escaping identifier with single quote."""
malicious = "tenant' OR '1'='1"
escaped = escape_identifier(malicious)
assert escaped == "tenant'' OR ''1''=''1"
def test_identifier_injection_attempt(self):
"""Test prevention of injection through identifiers."""
# Common identifier injection patterns
test_cases = [
("id' OR '1'='1", "id'' OR ''1''=''1"),
("id'; DROP TABLE", "id''; DROP TABLE"),
("id' UNION SELECT", "id'' UNION SELECT"),
]
for malicious, expected in test_cases:
assert escape_identifier(malicious) == expected
class TestSQLInjectionIntegration:
"""Integration tests simulating real SQL construction scenarios."""
def test_complete_where_clause_safety(self):
"""Test that a complete WHERE clause is safe from injection."""
# Simulating typical query construction
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
run_id = "run' --"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
escaped_run = escape_identifier(run_id)
sql = f"""
SELECT * FROM workflow_runs
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND id='{escaped_run}'
"""
# Verify all special characters are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
assert "run'' --" in sql
# Verify SQL structure is preserved (3 conditions with AND)
assert sql.count("AND") == 2
def test_multiple_conditions_with_injection_attempts(self):
"""Test multiple conditions all attempting injection."""
conditions = {
"tenant_id": "t1' OR tenant_id='t2",
"app_id": "a1' OR app_id='a2",
"status": "running' OR '1'='1",
}
where_parts = []
for field, value in conditions.items():
escaped = escape_sql_string(value)
where_parts.append(f"{field}='{escaped}'")
where_clause = " AND ".join(where_parts)
# All injection attempts are neutralized
assert "t1'' OR tenant_id=''t2" in where_clause
assert "a1'' OR app_id=''a2" in where_clause
assert "running'' OR ''1''=''1" in where_clause
# AND structure is preserved
assert where_clause.count(" AND ") == 2
@pytest.mark.parametrize(
("attack_vector", "description"),
[
("' OR '1'='1", "Boolean injection"),
("' OR '1'='1' --", "Boolean with comment"),
("' UNION SELECT * FROM users --", "Union injection"),
("'; DROP TABLE workflow_runs; --", "Destructive command"),
("' AND SLEEP(10) --", "Time-based blind"),
("' OR tenant_id LIKE '%", "Wildcard injection"),
("admin' --", "Comment bypass"),
("' OR 1=1 LIMIT 1 --", "Limit bypass"),
],
)
def test_common_injection_vectors(self, attack_vector, description):
"""Test protection against common injection attack vectors."""
escaped = escape_sql_string(attack_vector)
# Build SQL
sql = f"WHERE tenant_id='{escaped}'"
# Verify the attack string is now a safe literal
# The key indicator: all internal single quotes are doubled
internal_quotes = escaped.count("''")
original_quotes = attack_vector.count("'")
# Each original quote should be doubled
assert internal_quotes == original_quotes
# Verify SQL has exactly 2 quotes (opening and closing)
assert sql.count("'") >= 2 # At least opening and closing
def test_logstore_specific_scenario(self):
"""Test SQL injection prevention in LogStore-specific scenarios."""
# Simulate LogStore query with window function
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM workflow_execution_logstore
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND __time__ > 0
) AS subquery WHERE rn = 1
"""
# Complex query structure is maintained
assert "ROW_NUMBER()" in sql
assert "PARTITION BY id" in sql
# Injection attempts are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
# ====================================================================================
# Tests for LogStore Query Syntax (SDK Mode)
# ====================================================================================
class TestLogStoreQueryEscape:
"""Test escape_logstore_query_value for SDK mode query syntax."""
def test_normal_value(self):
"""Test escaping normal alphanumeric value."""
value = "550e8400-e29b-41d4-a716-446655440000"
escaped = escape_logstore_query_value(value)
# Should be wrapped in double quotes
assert escaped == '"550e8400-e29b-41d4-a716-446655440000"'
def test_empty_value(self):
"""Test escaping empty string."""
assert escape_logstore_query_value("") == '""'
def test_value_with_and_keyword(self):
"""Test that 'and' keyword is neutralized when quoted."""
malicious = "value and field:evil"
escaped = escape_logstore_query_value(malicious)
# Should be wrapped in quotes, making 'and' a literal
assert escaped == '"value and field:evil"'
# Simulate using in query
query = f"tenant_id:{escaped}"
assert query == 'tenant_id:"value and field:evil"'
def test_value_with_or_keyword(self):
"""Test that 'or' keyword is neutralized when quoted."""
malicious = "tenant_a or tenant_id:tenant_b"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"tenant_a or tenant_id:tenant_b"'
query = f"tenant_id:{escaped}"
assert "or" in query # Present but as literal string
def test_value_with_not_keyword(self):
"""Test that 'not' keyword is neutralized when quoted."""
malicious = "not field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"not field:value"'
def test_value_with_parentheses(self):
"""Test that parentheses are neutralized when quoted."""
malicious = "(tenant_a or tenant_b)"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"(tenant_a or tenant_b)"'
assert "(" in escaped # Present as literal
assert ")" in escaped # Present as literal
def test_value_with_colon(self):
"""Test that colons are neutralized when quoted."""
malicious = "field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"field:value"'
assert ":" in escaped # Present as literal
def test_value_with_double_quotes(self):
"""Test that internal double quotes are escaped."""
value_with_quotes = 'tenant"test"value'
escaped = escape_logstore_query_value(value_with_quotes)
# Double quotes should be escaped with backslash
assert escaped == '"tenant\\"test\\"value"'
# Should have outer quotes plus escaped inner quotes
assert '\\"' in escaped
def test_value_with_backslash(self):
"""Test that backslashes are escaped."""
value_with_backslash = "tenant\\test"
escaped = escape_logstore_query_value(value_with_backslash)
# Backslash should be escaped
assert escaped == '"tenant\\\\test"'
assert "\\\\" in escaped
def test_value_with_backslash_and_quote(self):
"""Test escaping both backslash and double quote."""
value = 'path\\to\\"file"'
escaped = escape_logstore_query_value(value)
# Both should be escaped
assert escaped == '"path\\\\to\\\\\\"file\\""'
# Verify escape order is correct
assert "\\\\" in escaped # Escaped backslash
assert '\\"' in escaped # Escaped double quote
def test_complex_injection_attempt(self):
"""Test complex injection combining multiple operators."""
malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")'
escaped = escape_logstore_query_value(malicious)
# All special chars should be literals or escaped
assert escaped.startswith('"')
assert escaped.endswith('"')
# Inner double quotes escaped, operators become literals
assert "or" in escaped
assert "and" in escaped
assert '\\"' in escaped # Escaped quotes
def test_only_backslash(self):
"""Test escaping a single backslash."""
assert escape_logstore_query_value("\\") == '"\\\\"'
def test_only_double_quote(self):
"""Test escaping a single double quote."""
assert escape_logstore_query_value('"') == '"\\""'
def test_multiple_backslashes(self):
"""Test escaping multiple consecutive backslashes."""
assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6
def test_escape_sequence_like_input(self):
"""Test that existing escape sequences are properly escaped."""
# Input looks like already escaped, but we still escape it
value = 'value\\"test'
escaped = escape_logstore_query_value(value)
# \\ -> \\\\, " -> \"
assert escaped == '"value\\\\\\"test"'
@pytest.mark.parametrize(
("attack_scenario", "field", "malicious_value"),
[
("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"),
("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"),
("Boolean logic", "status", "succeeded or status:failed"),
("Negation", "tenant_id", "not tenant_a"),
("Field injection", "run_id", "run123 and tenant_id:evil_tenant"),
("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"),
("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'),
("Backslash escape bypass", "app_id", "app\\ and app_id:evil"),
],
)
def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str):
"""Test that various LogStore query injection attempts are neutralized."""
escaped = escape_logstore_query_value(malicious_value)
# Build query
query = f"{field}:{escaped}"
# All operators should be within quoted string (literals)
assert escaped.startswith('"')
assert escaped.endswith('"')
# Verify the full query structure is safe
assert query.count(":") >= 1 # At least the main field:value separator

View File

@ -0,0 +1,59 @@
"""Unit tests for traceparent header propagation in EnterpriseRequest.
This test module verifies that the W3C traceparent header is properly
generated and included in HTTP requests made by EnterpriseRequest.
"""
from unittest.mock import MagicMock, patch
import pytest
from services.enterprise.base import EnterpriseRequest
class TestTraceparentPropagation:
"""Unit tests for traceparent header propagation."""
@pytest.fixture
def mock_enterprise_config(self):
"""Mock EnterpriseRequest configuration."""
with (
patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"),
patch.object(EnterpriseRequest, "secret_key", "test-secret-key"),
patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"),
):
yield
@pytest.fixture
def mock_httpx_client(self):
"""Mock httpx.Client for testing."""
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__.return_value = mock_client
mock_client_class.return_value.__exit__.return_value = None
# Setup default response
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_client.request.return_value = mock_response
yield mock_client
def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client):
"""Test that traceparent header is included when successfully generated."""
# Arrange
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
# Act
EnterpriseRequest.send_request("GET", "/test")
# Assert
mock_httpx_client.request.assert_called_once()
call_args = mock_httpx_client.request.call_args
headers = call_args[1]["headers"]
assert "traceparent" in headers
assert headers["traceparent"] == expected_traceparent
assert headers["Content-Type"] == "application/json"
assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key"

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from configs import dify_config
from models.account import Account
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import (
AccountAlreadyInTenantError,
@ -1147,9 +1147,13 @@ class TestRegisterService:
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
with patch("services.account_service.Session") as mock_session_class:
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None
# Mock RegisterService.register
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
@ -1182,9 +1186,59 @@ class TestRegisterService:
email="newuser@example.com",
name="newuser",
language="en-US",
status="pending",
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
def test_invite_new_member_normalizes_new_account_email(
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
):
"""Ensure inviting with mixed-case email normalizes before registering."""
mock_tenant = MagicMock()
mock_tenant.id = "tenant-456"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
mixed_email = "Invitee@Example.com"
mock_session = MagicMock()
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending"
)
with patch("services.account_service.RegisterService.register") as mock_register:
mock_register.return_value = mock_new_account
with (
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
):
mock_generate_token.return_value = "invite-token-abc"
RegisterService.invite_new_member(
tenant=mock_tenant,
email=mixed_email,
language="en-US",
role="normal",
inviter=mock_inviter,
)
mock_register.assert_called_once_with(
email="invitee@example.com",
name="invitee",
language="en-US",
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
@ -1207,9 +1261,13 @@ class TestRegisterService:
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
with patch("services.account_service.Session") as mock_session_class:
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account
# Mock the db.session.query for TenantAccountJoin
mock_db_query = MagicMock()
@ -1238,6 +1296,7 @@ class TestRegisterService:
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
mock_task_dependencies.delay.assert_called_once()
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
"""Test inviting a member who is already in the tenant."""
@ -1251,7 +1310,6 @@ class TestRegisterService:
# Mock database queries
query_results = {
("Account", "email", "existing@example.com"): mock_existing_account,
(
"TenantAccountJoin",
"tenant_id",
@ -1261,7 +1319,11 @@ class TestRegisterService:
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock TenantService methods
with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission:
with (
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
):
mock_lookup.return_value = mock_existing_account
# Execute test and verify exception
self._assert_exception_raised(
AccountAlreadyInTenantError,
@ -1272,6 +1334,7 @@ class TestRegisterService:
role="normal",
inviter=mock_inviter,
)
mock_lookup.assert_called_once()
def test_invite_new_member_no_inviter(self):
"""Test inviting a member without providing an inviter."""
@ -1497,6 +1560,30 @@ class TestRegisterService:
# Verify results
assert result is None
def test_get_invitation_with_case_fallback_returns_initial_match(self):
"""Fallback helper should return the initial invitation when present."""
invitation = {"workspace_id": "tenant-456"}
with patch(
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
) as mock_get:
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
assert result == invitation
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
"""Fallback helper should retry with lowercase email when needed."""
invitation = {"workspace_id": "tenant-456"}
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
mock_get.side_effect = [None, invitation]
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
assert result == invitation
assert mock_get.call_args_list == [
(("tenant-456", "User@Test.com", "token-123"),),
(("tenant-456", "user@test.com", "token-123"),),
]
# ==================== Helper Method Tests ====================
def test_get_invitation_token_key(self):

View File

@ -453,15 +453,15 @@ wheels = [
[[package]]
name = "azure-core"
version = "1.36.0"
version = "1.38.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "requests" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
{ url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" },
]
[[package]]
@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.11.2"
version = "1.11.3"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@ -1965,11 +1965,11 @@ wheels = [
[[package]]
name = "filelock"
version = "3.20.0"
version = "3.20.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" }
sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
{ url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" },
]
[[package]]

View File

@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# Workflow log cleanup configuration

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.2
image: langgenius/dify-web:1.11.3
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -704,7 +704,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -746,7 +746,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -785,7 +785,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -815,7 +815,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.2
image: langgenius/dify-web:1.11.3
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -31,6 +31,8 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
# The timeout for the text generation in millisecond
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only)
TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST=

View File

@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => {
)
return {
useGlobalPublicStore,
useIsSystemFeaturesPending: () => false,
}
})

View File

@ -9,8 +9,8 @@ import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
} from '@/app/education-apply/constants'
import { fetchSetupStatus } from '@/service/common'
import { sendGAEvent } from '@/utils/gtag'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
import { trackEvent } from './base/amplitude'
@ -33,15 +33,8 @@ export const AppInitializer = ({
const isSetupFinished = useCallback(async () => {
try {
if (localStorage.getItem('setup_status') === 'finished')
return true
const setUpStatus = await fetchSetupStatus()
if (setUpStatus.step !== 'finished') {
localStorage.removeItem('setup_status')
return false
}
localStorage.setItem('setup_status', 'finished')
return true
const setUpStatus = await fetchSetupStatusWithCache()
return setUpStatus.step === 'finished'
}
catch (error) {
console.error(error)

View File

@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({
}),
}))
vi.mock('@/service/common', () => ({
fetchCurrentWorkspace: vi.fn(),
fetchLangGeniusVersion: vi.fn(),
fetchUserProfile: vi.fn(),
getSystemFeatures: vi.fn(),
}))
vi.mock('@/service/access-control', () => ({
useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args),
useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args),
@ -125,7 +118,6 @@ const resetAccessControlStore = () => {
const resetGlobalStore = () => {
useGlobalPublicStore.setState({
systemFeatures: defaultSystemFeatures,
isGlobalPending: false,
})
}

View File

@ -54,7 +54,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => {
}
const AmplitudeProvider: FC<IAmplitudeProps> = ({
sessionReplaySampleRate = 1,
sessionReplaySampleRate = 0.5,
}) => {
useEffect(() => {
// Only enable in Saas edition with valid API key

View File

@ -170,8 +170,12 @@ describe('useChatWithHistory', () => {
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1')
})
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
expect(result.current.conversationList).toEqual(listData.data)
await waitFor(() => {
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
})
await waitFor(() => {
expect(result.current.conversationList).toEqual(listData.data)
})
})
})

View File

@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({
}))
vi.mock('@/service/billing', () => ({
fetchBillingUrl: vi.fn(),
fetchSubscriptionUrls: vi.fn(),
}))
vi.mock('@/service/client', () => ({
consoleClient: {
billingUrl: vi.fn(),
},
}))
vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: vi.fn(),
}))
@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({
const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockFetchBillingUrl = fetchBillingUrl as Mock
const mockBillingUrl = consoleClient.billingUrl as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock
@ -69,7 +75,7 @@ beforeEach(() => {
vi.clearAllMocks()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
assignedHref = ''
})
@ -143,7 +149,7 @@ describe('CloudPlanItem', () => {
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
expect(mockFetchBillingUrl).not.toHaveBeenCalled()
expect(mockBillingUrl).not.toHaveBeenCalled()
})
it('should open billing portal when upgrading current paid plan', async () => {
@ -162,7 +168,7 @@ describe('CloudPlanItem', () => {
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
await waitFor(() => {
expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1)
expect(mockBillingUrl).toHaveBeenCalledTimes(1)
})
expect(openWindow).toHaveBeenCalledTimes(1)
})

View File

@ -6,7 +6,8 @@ import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
@ -76,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
try {
if (isCurrentPaidPlan) {
await openAsyncWindow(async () => {
const res = await fetchBillingUrl()
const res = await consoleClient.billingUrl()
if (res.url)
return res.url
throw new Error('Failed to open billing page')

View File

@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
category: PluginCategoryEnum.datasource,
exclude,
type: 'plugin',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
else {
@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
query: '',
category: PluginCategoryEnum.datasource,
type: 'plugin',
pageSize: 1000,
page_size: 1000,
exclude,
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
category: PluginCategoryEnum.model,
exclude,
type: 'plugin',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
else {
@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
query: '',
category: PluginCategoryEnum.model,
type: 'plugin',
pageSize: 1000,
page_size: 1000,
exclude,
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => {
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()
const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => {
const pageSize = pluginsSearchParams.pageSize || 40
const page_size = pluginsSearchParams.page_size || 40
return {
...pluginsSearchParams,
pageSize,
page_size,
}
}, [])
@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => {
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
page_size: 40,
}
}
const params = normalizeParams(queryParams)
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
exclude,
type,
pageSize,
page_size,
} = params
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: {
page: pageParam,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
exclude,
@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
page_size,
}
}
catch {
@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => {
plugins: [],
total: 0,
page: pageParam,
pageSize,
page_size,
}
}
},
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,

View File

@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'
import { marketplaceQuery } from '@/service/client'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceKeys } from './query'
import { marketplaceSearchParamsParsers } from './search-params'
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise<SearchParams>) {
const queryClient = getQueryClientServer()
await queryClient.prefetchQuery({
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }),
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
})
return dehydrate(queryClient)

View File

@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({
// Mock tanstack query
const mockFetchNextPage = vi.fn()
const mockHasNextPage = false
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined
let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null
vi.mock('@tanstack/react-query', () => ({
useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise<unknown>, enabled: boolean }) => {
@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({
}),
useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: {
queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>
getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined
getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined
enabled: boolean
}) => {
// Capture queryFn and getNextPageParam for later testing
@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({
// Call getNextPageParam to increase coverage
if (getNextPageParam) {
// Test with more data available
getNextPageParam({ page: 1, pageSize: 40, total: 100 })
getNextPageParam({ page: 1, page_size: 40, total: 100 })
// Test with no more data
getNextPageParam({ page: 3, pageSize: 40, total: 100 })
getNextPageParam({ page: 3, page_size: 40, total: 100 })
}
return {
data: mockInfiniteQueryData,
@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({
// Mock config
vi.mock('@/config', () => ({
API_PREFIX: '/api',
APP_VERSION: '1.0.0',
IS_MARKETPLACE: false,
MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1',
@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => {
expect(() => {
result.current.queryPlugins({
query: 'test',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
category: 'tool',
pageSize: 20,
page_size: 20,
})
}).not.toThrow()
})
@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({
query: 'test',
type: 'bundle',
pageSize: 40,
page_size: 40,
})
}).not.toThrow()
})
@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({
query: 'test',
category: 'all',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}).not.toThrow()
})
@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => {
expect(() => {
result.current.queryPlugins({
query: 'test',
pageSize: 100,
page_size: 100,
})
}).not.toThrow()
})
@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => {
// Set mock data to have pages
mockInfiniteQueryData = {
pages: [
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 },
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 },
],
}
@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => {
it('should expose page and total from infinite query data', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 },
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 },
],
}
@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => {
it('should return total from first page when query is set and data exists', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [], total: 50, page: 1, pageSize: 40 },
{ plugins: [], total: 50, page: 1, page_size: 40 },
],
}
@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => {
type: 'plugin',
query: 'search test',
category: 'model',
sortBy: 'version_updated_at',
sortOrder: 'ASC',
sort_by: 'version_updated_at',
sort_order: 'ASC',
})
expect(result.current).toBeDefined()
@ -1027,13 +1028,13 @@ describe('Advanced Hook Integration', () => {
// Test with all possible parameters
result.current.queryPlugins({
query: 'comprehensive test',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
category: 'tool',
tags: ['tag1', 'tag2'],
exclude: ['excluded-plugin'],
type: 'plugin',
pageSize: 50,
page_size: 50,
})
expect(result.current).toBeDefined()
@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({
query: 'direct test',
category: 'tool',
sortBy: 'install_count',
sortOrder: 'DESC',
pageSize: 40,
sort_by: 'install_count',
sort_order: 'DESC',
page_size: 40,
})
// Now queryFn should be captured and enabled
@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({
query: 'structure test',
pageSize: 20,
page_size: 20,
})
if (capturedInfiniteQueryFn) {
@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// Verify the returned structure
expect(response).toHaveProperty('plugins')
expect(response).toHaveProperty('total')
expect(response).toHaveProperty('page')
expect(response).toHaveProperty('pageSize')
expect(response).toHaveProperty('page_size')
}
})
})
@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => {
],
total: 5,
page: 1,
pageSize: 40,
page_size: 40,
},
{
plugins: [
@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => {
],
total: 5,
page: 2,
pageSize: 40,
page_size: 40,
},
],
}
@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => {
it('should test hook with pages data for flatMap path', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [], total: 100, page: 1, pageSize: 40 },
{ plugins: [], total: 100, page: 2, pageSize: 40 },
{ plugins: [], total: 100, page: 1, page_size: 40 },
{ plugins: [], total: 100, page: 2, page_size: 40 },
],
}
@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// When error is caught, should return fallback data
expect(response.plugins).toEqual([])
@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => {
// Test getNextPageParam function directly
if (capturedGetNextPageParam) {
// When there are more pages
const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 })
const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 })
expect(nextPage).toBe(2)
// When all data is loaded
const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 })
const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 })
expect(noMorePages).toBeUndefined()
// Edge case: exactly at boundary
const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 })
const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 })
expect(atBoundary).toBeUndefined()
}
})
@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// Catch block should return fallback values
expect(response.plugins).toEqual([])
@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => {
plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }],
total: 10,
page: 1,
pageSize: 40,
page_size: 40,
},
],
}
@ -1489,9 +1490,12 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' },
]
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', {
@ -1514,19 +1518,26 @@ describe('Async Utils', () => {
})
it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({ signal: controller.signal }),
expect.any(Request),
expect.any(Object),
)
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
})
})
@ -1535,19 +1546,25 @@ describe('Async Utils', () => {
const mockCollections = [
{ name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' },
]
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
let callCount = 0
globalThis.fetch = vi.fn().mockImplementation(() => {
callCount++
if (callCount === 1) {
return Promise.resolve({
json: () => Promise.resolve({ data: { collections: mockCollections } }),
})
return Promise.resolve(
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
}
return Promise.resolve({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
return Promise.resolve(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
})
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@ -1571,9 +1588,12 @@ describe('Async Utils', () => {
})
it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { collections: [] } }),
})
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { collections: [] } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({
@ -1581,10 +1601,11 @@ describe('Async Utils', () => {
type: 'bundle',
})
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.stringContaining('condition=category=tool'),
expect.any(Object),
)
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalled()
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
})
})
})

View File

@ -1,22 +1,14 @@
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
import type { PluginsSearchParams } from './types'
import type { MarketPlaceInputs } from '@/contract/router'
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
import { marketplaceQuery } from '@/service/client'
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
// TODO: Avoid manual maintenance of query keys and better service management,
// https://github.com/langgenius/dify/issues/30342
export const marketplaceKeys = {
all: ['marketplace'] as const,
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
}
export function useMarketplaceCollectionsAndPlugins(
collectionsParams: CollectionsAndPluginsSearchParams,
collectionsParams: MarketPlaceInputs['collections']['query'],
) {
return useQuery({
queryKey: marketplaceKeys.collections(collectionsParams),
queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }),
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
})
}
@ -25,11 +17,16 @@ export function useMarketplacePlugins(
queryParams: PluginsSearchParams | undefined,
) {
return useInfiniteQuery({
queryKey: marketplaceKeys.plugins(queryParams),
queryKey: marketplaceQuery.searchAdvanced.queryKey({
input: {
body: queryParams!,
params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' },
},
}),
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,

View File

@ -26,8 +26,8 @@ export function useMarketplaceData() {
query: searchPluginText,
category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType,
tags: filterPluginTags,
sortBy: sort.sortBy,
sortOrder: sort.sortOrder,
sort_by: sort.sortBy,
sort_order: sort.sortOrder,
type: getMarketplaceListFilterType(activePluginType),
}
}, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort])

View File

@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = {
export type PluginsSearchParams = {
query: string
page?: number
pageSize?: number
sortBy?: string
sortOrder?: string
page_size?: number
sort_by?: string
sort_order?: string
category?: string
tags?: string[]
exclude?: string[]

View File

@ -4,14 +4,12 @@ import type {
MarketplaceCollection,
PluginsSearchParams,
} from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
import type { Plugin } from '@/app/components/plugins/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import {
APP_VERSION,
IS_MARKETPLACE,
MARKETPLACE_API_PREFIX,
} from '@/config'
import { postMarketplace } from '@/service/base'
import { marketplaceClient } from '@/service/client'
import { getMarketplaceUrl } from '@/utils/var'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
@ -19,10 +17,6 @@ type MarketplaceFetchOptions = {
signal?: AbortSignal
}
const getMarketplaceHeaders = () => new Headers({
'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
})
export const getPluginIconInMarketplace = (plugin: Plugin) => {
if (plugin.type === 'bundle')
return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon`
@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async (
let plugins: Plugin[] = []
try {
const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins`
const headers = getMarketplaceHeaders()
const marketplaceCollectionPluginsData = await globalThis.fetch(
url,
{
cache: 'no-store',
method: 'POST',
headers,
signal: options?.signal,
body: JSON.stringify({
category: query?.category,
exclude: query?.exclude,
type: query?.type,
}),
const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({
params: {
collectionId,
},
)
const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json()
plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin))
body: query,
}, {
signal: options?.signal,
})
plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin))
}
// eslint-disable-next-line unused-imports/no-unused-vars
catch (e) {
@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async (
let marketplaceCollections: MarketplaceCollection[] = []
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
try {
let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100`
if (query?.condition)
marketplaceUrl += `&condition=${query.condition}`
if (query?.type)
marketplaceUrl += `&type=${query.type}`
const headers = getMarketplaceHeaders()
const marketplaceCollectionsData = await globalThis.fetch(
marketplaceUrl,
{
headers,
cache: 'no-store',
signal: options?.signal,
const marketplaceCollectionsDataJson = await marketplaceClient.collections({
query: {
...query,
page: 1,
page_size: 100,
},
)
const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json()
marketplaceCollections = marketplaceCollectionsDataJson.data.collections || []
}, {
signal: options?.signal,
})
marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || []
await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {
const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options)
@ -143,42 +122,42 @@ export const getMarketplacePlugins = async (
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
page_size: 40,
}
}
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
type,
pageSize = 40,
page_size = 40,
} = queryParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
try {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
const res = await marketplaceClient.searchAdvanced({
params: {
kind: type === 'bundle' ? 'bundles' : 'plugins',
},
body: {
page: pageParam,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
type,
},
signal,
})
}, { signal })
const resPlugins = res.data.bundles || res.data.plugins || []
return {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
page_size,
}
}
catch {
@ -186,7 +165,7 @@ export const getMarketplacePlugins = async (
plugins: [],
total: 0,
page: pageParam,
pageSize,
page_size,
}
}
}

View File

@ -1606,6 +1606,7 @@ export const useNodesInteractions = () => {
const offsetX = currentPosition.x - x
const offsetY = currentPosition.y - y
let idMapping: Record<string, string> = {}
const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = []
clipboardElements.forEach((nodeToPaste, index) => {
const nodeType = nodeToPaste.data.type
@ -1619,6 +1620,7 @@ export const useNodesInteractions = () => {
_isBundled: false,
_connectedSourceHandleIds: [],
_connectedTargetHandleIds: [],
_dimmed: false,
title: genNewNodeTitleFromOld(nodeToPaste.data.title),
},
position: {
@ -1686,27 +1688,24 @@ export const useNodesInteractions = () => {
return
// handle paste to nested block
if (selectedNode.data.type === BlockEnum.Iteration) {
newNode.data.isInIteration = true
newNode.data.iteration_id = selectedNode.data.iteration_id
newNode.parentId = selectedNode.id
newNode.positionAbsolute = {
x: newNode.position.x,
y: newNode.position.y,
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
}
else if (selectedNode.data.type === BlockEnum.Loop) {
newNode.data.isInLoop = true
newNode.data.loop_id = selectedNode.data.loop_id
if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) {
const isIteration = selectedNode.data.type === BlockEnum.Iteration
newNode.data.isInIteration = isIteration
newNode.data.iteration_id = isIteration ? selectedNode.id : undefined
newNode.data.isInLoop = !isIteration
newNode.data.loop_id = !isIteration ? selectedNode.id : undefined
newNode.parentId = selectedNode.id
newNode.zIndex = isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
newNode.positionAbsolute = {
x: newNode.position.x,
y: newNode.position.y,
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
// update parent children array like native add
parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type })
}
}
}
@ -1737,7 +1736,17 @@ export const useNodesInteractions = () => {
}
})
setNodes([...nodes, ...nodesToPaste])
const newNodes = produce(nodes, (draft: Node[]) => {
parentChildrenToAppend.forEach(({ parentId, childId, childType }) => {
const p = draft.find(n => n.id === parentId)
if (p) {
p.data._children?.push({ nodeId: childId, nodeType: childType })
}
})
draft.push(...nodesToPaste)
})
setNodes(newNodes)
setEdges([...edges, ...edgesToPaste])
saveStateToHistory(WorkflowHistoryEvent.NodePaste, {
nodeId: nodesToPaste?.[0]?.id,

Some files were not shown because too many files have changed in this diff Show More