Merge branch 'main' into feat/hitl-frontend

This commit is contained in:
twwu 2026-01-14 13:24:56 +08:00
commit dfb25df5ec
125 changed files with 4121 additions and 1224 deletions

View File

@ -35,7 +35,7 @@ from libs.rsa import generate_key_pair
from models import Tenant from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID from models.provider_ids import DatasourceProviderID, ToolProviderID
@ -64,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip(): if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red")) click.echo(click.style("Passwords do not match.", fg="red"))
return return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none() account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account: if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red")) click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@ -86,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email) AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green")) click.echo(click.style("Password reset successfully.", fg="green"))
@ -102,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip(): if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red")) click.echo(click.style("New emails do not match.", fg="red"))
return return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none() account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account: if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red")) click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return return
try: try:
email_validate(new_email) email_validate(normalized_new_email)
except: except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red")) click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return return
account.email = new_email account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green")) click.echo(click.style("Email updated successfully.", fg="green"))
@ -660,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return return
# Create account # Create account
email = email.strip() email = email.strip().lower()
if "@" not in email: if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red")) click.echo(click.style("Invalid email address.", fg="red"))

View File

@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings): class VolcengineTOSStorageConfig(BaseSettings):
""" """
Configuration settings for Volcengine Tinder Object Storage (TOS) Configuration settings for Volcengine Torch Object Storage (TOS)
""" """
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(

View File

@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id workspaceId = args.workspace_id
reg_email = args.email
token = args.token token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation: if invitation:
data = invitation.get("data", {}) data = invitation.get("data", {})
tenant = invitation.get("tenant", None) tenant = invitation.get("tenant", None)
@ -100,11 +99,12 @@ class ActivateApi(Resource):
def post(self): def post(self):
args = ActivatePayload.model_validate(console_ns.payload) args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token) normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None: if invitation is None:
raise AlreadyActivateError() raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, args.email, args.token) RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"] account = invitation["account"]
account.name = args.name account.name = args.name

View File

@ -1,7 +1,6 @@
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled @email_register_enabled
def post(self): def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload) args = EmailRegisterSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages: if args.language in languages:
language = args.language language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError() raise AccountInFreezeError()
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = None token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self): def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload) args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args.email user_email = args.email.lower()
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email) is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit: if is_email_register_error_rate_limit:
raise EmailRegisterLimitError() raise EmailRegisterLimitError()
@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError() raise InvalidEmailError()
if args.code != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email) AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"} user_email, code=args.code, additional_data={"phase": "register"}
) )
AccountService.reset_email_register_error_rate_limit(args.email) AccountService.reset_email_register_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register") @console_ns.route("/email-register")
@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token) AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "") email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
else: else:
account = self._create_new_account(email, args.password_confirm) account = self._create_new_account(normalized_email, args.password_confirm)
if not account: if not account:
raise AccountNotFoundError() raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(email) AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email, password) -> Account | None: def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed # Create new account if allowed
account = None account = None
try: try:

View File

@ -4,7 +4,6 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import console_ns
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload) args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US" language = "en-US"
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
account=account, account=account,
email=args.email, email=normalized_email,
language=language, language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register, is_allow_register=FeatureService.get_system_features().is_allow_register,
) )
@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self): def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload) args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args.email user_email = args.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email) is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit: if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError() raise EmailPasswordResetLimitError()
@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError() raise InvalidEmailError()
if args.code != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email) AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token( _, new_token = AccountService.generate_reset_password_token(
user_email, code=args.code, additional_data={"phase": "reset"} token_email, code=args.code, additional_data={"phase": "reset"}
) )
AccountService.reset_forgot_password_error_rate_limit(args.email) AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets") @console_ns.route("/forgot-password/resets")
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt) password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "") email = reset_data.get("email", "")
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
self._update_existing_account(account, password_hashed, salt, session) self._update_existing_account(account, password_hashed, salt, session)

View File

@ -90,32 +90,38 @@ class LoginApi(Resource):
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload) args = LoginPayload.model_validate(console_ns.payload)
request_email = args.email
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError() raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email) is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit: if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError() raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None invitation_data: dict[str, Any] | None = None
if args.invite_token: if invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token) invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
invite_token = None
try: try:
if invitation_data: if invitation_data:
data = invitation_data.get("data", {}) data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None invitee_email = data.get("email") if data else None
if invitee_email != args.email: invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
raise InvalidEmailError() raise InvalidEmailError()
account = AccountService.authenticate(args.email, args.password, args.invite_token) account = _authenticate_account_with_case_fallback(
else: request_email, normalized_email, args.password, invite_token
account = AccountService.authenticate(args.email, args.password) )
except services.errors.account.AccountLoginError: except services.errors.account.AccountLoginError:
raise AccountBannedError() raise AccountBannedError()
except services.errors.account.AccountPasswordError: except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(args.email) AccountService.add_login_error_rate_limit(normalized_email)
raise AuthenticationFailedError() raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0: if len(tenants) == 0:
@ -130,7 +136,7 @@ class LoginApi(Resource):
} }
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email) AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body # Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"}) response = make_response({"result": "success"})
@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__]) @console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self): def post(self):
args = EmailPayload.model_validate(console_ns.payload) args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args.email) account = _get_account_with_case_fallback(args.email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
email=args.email, email=normalized_email,
account=account, account=account,
language=language, language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register, is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__]) @console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self): def post(self):
args = EmailPayload.model_validate(console_ns.payload) args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args.email) account = _get_account_with_case_fallback(args.email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args.email, language=language) token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else: else:
raise AccountNotFound() raise AccountNotFound()
else: else:
@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self): def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload) args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args.email original_email = args.email
user_email = original_email.lower()
language = args.language language = args.language
token_data = AccountService.get_email_code_login_data(args.token) token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if token_data["email"] != args.email: token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
raise InvalidEmailError() raise InvalidEmailError()
if token_data["code"] != args.code: if token_data["code"] != args.code:
@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token) AccountService.revoke_email_code_login_token(args.token)
try: try:
account = AccountService.get_user_through_email(user_email) account = _get_account_with_case_fallback(original_email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account: if account:
@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError: except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email) AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body # Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"}) response = make_response({"result": "success"})
@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
return response return response
except Exception as e: except Exception as e:
return {"result": "fail", "message": str(e)}, 401 return {"result": "fail", "message": str(e)}, 401
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)

View File

@ -3,7 +3,6 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
@ -118,7 +117,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token) invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation: if invitation:
invitation_email = invitation.get("email", None) invitation_email = invitation.get("email", None)
if invitation_email != user_info.email: invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account: if not account:
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account return account
@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant) tenant_was_created.send(new_tenant)
if not account: if not account:
normalized_email = user_info.email.lower()
oauth_new_user = True oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register: if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError( raise AccountRegisterError(
description=( description=(
"This email account has been deleted within the past " "This email account has been deleted within the past "
@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password")) raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
) )
# Set interface language # Set interface language

View File

@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str name: str
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
status: str | None = Field(None, title="Status", description="Document status.")
fetch_val: str = Field("false", alias="fetch")
register_schema_models( register_schema_models(
console_ns, console_ns,
KnowledgeConfig, KnowledgeConfig,
@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) raw_args = request.args.to_dict()
limit = request.args.get("limit", default=20, type=int) param = DocumentDatasetListParam.model_validate(raw_args)
search = request.args.get("keyword", default=None, type=str) page = param.page
sort = request.args.get("sort", default="-created_at", type=str) limit = param.limit
status = request.args.get("status", default=None, type=str) search = param.search
sort = param.sort_by
status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False. # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try: try:
fetch_val = request.args.get("fetch", default="false") fetch_val = param.fetch_val
if isinstance(fetch_val, bool): if isinstance(fetch_val, bool):
fetch = fetch_val fetch = fetch_val
else: else:

View File

@ -84,10 +84,11 @@ class SetupApi(Resource):
raise NotInitValidateError() raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload) args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup # setup
RegisterService.setup( RegisterService.setup(
email=args.email, email=normalized_email,
name=args.name, name=args.name,
password=args.password, password=args.password,
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),

View File

@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
account = None account = None
user_email = args.email user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email": if args.phase is not None and args.phase == "new_email":
if args.token is None: if args.token is None:
raise InvalidTokenError() raise InvalidTokenError()
@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
user_email = reset_data.get("email", "") user_email = reset_data.get("email", "")
if user_email != current_user.email: if user_email.lower() != current_user.email.lower():
raise InvalidEmailError() raise InvalidEmailError()
user_email = current_user.email
else: else:
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None: if account is None:
raise AccountNotFound() raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email( token = AccountService.send_change_email_email(
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase account=account,
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
) )
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {} payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload) args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email user_email = args.email.lower()
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email) is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit: if is_change_email_error_rate_limit:
raise EmailChangeLimitError() raise EmailChangeLimitError()
@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError() raise InvalidEmailError()
if args.code != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email) AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
) )
AccountService.reset_change_email_error_rate_limit(args.email) AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset") @console_ns.route("/account/change-email/reset")
@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
def post(self): def post(self):
payload = console_ns.payload or {} payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload) args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
if AccountService.is_account_in_freeze(args.new_email): if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError() raise AccountInFreezeError()
if not AccountService.check_email_unique(args.new_email): if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token) reset_data = AccountService.get_change_email_data(args.token)
@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if current_user.email != old_email: if current_user.email.lower() != old_email.lower():
raise AccountNotFound() raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args.new_email) updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email( AccountService.send_change_email_completed_notify_email(
email=args.new_email, email=normalized_new_email,
) )
return updated_account return updated_account
@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
def post(self): def post(self):
payload = console_ns.payload or {} payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload) args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email): normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError() raise AccountInFreezeError()
if not AccountService.check_email_unique(args.email): if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
return {"result": "success"} return {"result": "success"}

View File

@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded() raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails: for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try: try:
if not inviter.current_tenant: if not inviter.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
token = RegisterService.invite_new_member( token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
) )
encoded_invitee_email = parse.quote(invitee_email) encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append( invitation_results.append(
{ {
"status": "success", "status": "success",
"email": invitee_email, "email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}", "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
} }
) )
except AccountAlreadyInTenantError: except AccountAlreadyInTenantError:
invitation_results.append( invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
) )
except Exception as e: except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return { return {
"result": "success", "result": "success",

View File

@ -4,7 +4,6 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models import Account from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService
@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self): def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {}) payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
request_email = payload.email
normalized_email = request_email.lower()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US" language = "en-US"
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None token = None
if account is None: if account is None:
raise AuthenticationFailedError() raise AuthenticationFailedError()
else: else:
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language) token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self): def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {}) payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = payload.email user_email = payload.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email) is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit: if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError() raise EmailPasswordResetLimitError()
@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError() raise InvalidEmailError()
if payload.code != token_data.get("code"): if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(payload.email) AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token( _, new_token = AccountService.generate_reset_password_token(
user_email, code=payload.code, additional_data={"phase": "reset"} token_email, code=payload.code, additional_data={"phase": "reset"}
) )
AccountService.reset_forgot_password_error_rate_limit(payload.email) AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets") @web_ns.route("/forgot-password/resets")
@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "") email = reset_data.get("email", "")
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
self._update_existing_account(account, password_hashed, salt, session) self._update_existing_account(account, password_hashed, salt, session)

View File

@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"].lower()
token_data = WebAppAuthService.get_email_code_login_data(args["token"]) token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if token_data["email"] != args["email"]: token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
raise InvalidEmailError() raise InvalidEmailError()
if token_data["code"] != args["code"]: if token_data["code"] != args["code"]:
raise EmailCodeError() raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"]) WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email) account = WebAppAuthService.get_user_through_email(token_email)
if not account: if not account:
raise AuthenticationFailedError() raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account) token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}}) response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response return response

View File

@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
), ),
) )
assistant_message = AssistantPromptMessage(content="", tool_calls=[]) assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls: if tool_calls:
assistant_message.tool_calls = [ assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall( AssistantPromptMessage.ToolCall(
@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
) )
for tool_call in tool_calls for tool_call in tool_calls
] ]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message) self._current_thoughts.append(assistant_message)

View File

@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory from core.db.session_factory import session_factory
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion from core.variables.variables import Variable
from core.workflow.enums import WorkflowType from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=self._workflow.environment_variables, environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `Variable`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
) )
@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager, trace_manager=app_generate_entity.trace_manager,
) )
def _initialize_conversation_variables(self) -> list[VariableUnion]: def _initialize_conversation_variables(self) -> list[Variable]:
""" """
Initialize conversation variables for the current conversation. Initialize conversation variables for the current conversation.
@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables] conversation_variables = [var.to_variable() for var in existing_variables]
session.commit() session.commit()
return cast(list[VariableUnion], conversation_variables) return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
""" """

View File

@ -189,7 +189,7 @@ class BaseAppGenerator:
elif value == 0: elif value == 0:
value = False value = False
case VariableEntityType.JSON_OBJECT: case VariableEntityType.JSON_OBJECT:
if not isinstance(value, dict): if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict") raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _: case _:
raise AssertionError("this statement should be unreachable.") raise AssertionError("this statement should be unreachable.")

View File

@ -1,6 +1,6 @@
import logging import logging
from core.variables import Variable from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID: if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue continue
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable): if not isinstance(variable, VariableBase):
logger.warning( logger.warning(
"Conversation variable not found in variable pool. selector=%s", "Conversation variable not found in variable pool. selector=%s",
selector, selector,

View File

@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise :return: True if prompt message is empty, False otherwise
""" """
if not super().is_empty() and not self.tool_calls: return super().is_empty() and not self.tool_calls
return False
return True
class SystemPromptMessage(PromptMessage): class SystemPromptMessage(PromptMessage):

View File

@ -1,6 +1,7 @@
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import ( from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo, ToolTraceInfo,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db from extensions.ext_database import db
@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
), ),
status=status, status=status,
links=trace_metadata.links, links=trace_metadata.links,
span_kind=SpanKind.SERVER,
) )
self.trace_client.add_span(message_span) self.trace_client.add_span(message_span)
@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id) service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine) session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=service_account, user=service_account,
app_id=app_id, app_id=app_id,
@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
), ),
status=status, status=status,
links=trace_metadata.links, links=trace_metadata.links,
span_kind=SpanKind.SERVER,
) )
self.trace_client.add_span(message_span) self.trace_client.add_span(message_span)
@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
), ),
status=status, status=status,
links=trace_metadata.links, links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
) )
self.trace_client.add_span(workflow_span) self.trace_client.add_span(workflow_span)

View File

@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes, attributes=span_data.attributes,
events=span_data.events, events=span_data.events,
links=span_data.links, links=span_data.links,
kind=trace_api.SpanKind.INTERNAL, kind=span_data.span_kind,
status=span_data.status, status=span_data.status,
start_time=span_data.start_time, start_time=span_data.start_time,
end_time=span_data.end_time, end_time=span_data.end_time,

View File

@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.") start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.") end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")

View File

@ -7,8 +7,8 @@ from typing import Any, cast
from flask import has_request_context from flask import has_request_context
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping from factories.file_factory import build_from_mapping
from libs.login import current_user from libs.login import current_user
from models import Account, Tenant from models import Account, Tenant
@ -230,30 +229,32 @@ class WorkflowTool(Tool):
""" """
Resolve user from database (worker/Celery context). Resolve user from database (worker/Celery context).
""" """
with session_factory.create_session() as session:
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = session.scalar(user_stmt)
if user:
user.current_tenant = tenant
session.expunge(user)
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = session.scalar(end_user_stmt)
if end_user:
session.expunge(end_user)
return end_user
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
def _get_workflow(self, app_id: str, version: str) -> Workflow: def _get_workflow(self, app_id: str, version: str) -> Workflow:
""" """
get the workflow by app id and version get the workflow by app id and version
""" """
with Session(db.engine, expire_on_commit=False) as session, session.begin(): with session_factory.create_session() as session, session.begin():
if not version: if not version:
stmt = ( stmt = (
select(Workflow) select(Workflow)
@ -265,22 +266,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt) workflow = session.scalar(stmt)
if not workflow: if not workflow:
raise ValueError("workflow not found or not published") raise ValueError("workflow not found or not published")
return workflow session.expunge(workflow)
return workflow
def _get_app(self, app_id: str) -> App: def _get_app(self, app_id: str) -> App:
""" """
get the app by app id get the app by app id
""" """
stmt = select(App).where(App.id == app_id) stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin(): with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt) app = session.scalar(stmt)
if not app: if not app:
raise ValueError("app not found") raise ValueError("app not found")
return app session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
""" """

View File

@ -30,6 +30,7 @@ from .variables import (
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, Variable,
VariableBase,
) )
__all__ = [ __all__ = [
@ -62,4 +63,5 @@ __all__ = [
"StringSegment", "StringSegment",
"StringVariable", "StringVariable",
"Variable", "Variable",
"VariableBase",
] ]

View File

@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class. # - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except: # - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool. # - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`. # - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[ SegmentUnion: TypeAlias = Annotated[
( (
Annotated[NoneSegment, Tag(SegmentType.NONE)] Annotated[NoneSegment, Tag(SegmentType.NONE)]

View File

@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType from .types import SegmentType
class Variable(Segment): class VariableBase(Segment):
""" """
A variable is a segment that has a name. A variable is a segment that has a name.
@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list) selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable): class StringVariable(StringSegment, VariableBase):
pass pass
class FloatVariable(FloatSegment, Variable): class FloatVariable(FloatSegment, VariableBase):
pass pass
class IntegerVariable(IntegerSegment, Variable): class IntegerVariable(IntegerSegment, VariableBase):
pass pass
class ObjectVariable(ObjectSegment, Variable): class ObjectVariable(ObjectSegment, VariableBase):
pass pass
class ArrayVariable(ArraySegment, Variable): class ArrayVariable(ArraySegment, VariableBase):
pass pass
@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value) return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable): class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE value_type: SegmentType = SegmentType.NONE
value: None = None value: None = None
class FileVariable(FileSegment, Variable): class FileVariable(FileSegment, VariableBase):
pass pass
class BooleanVariable(BooleanSegment, Variable): class BooleanVariable(BooleanSegment, VariableBase):
pass pass
@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. # The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required. # Use `VariableBase` for type hinting when serialization is not required.
# #
# Note: # Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class. # - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `Segment`, except: # - The union must include all non-abstract subclasses of `VariableBase`.
VariableUnion: TypeAlias = Annotated[ Variable: TypeAlias = Annotated[
( (
Annotated[NoneVariable, Tag(SegmentType.NONE)] Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)] | Annotated[StringVariable, Tag(SegmentType.STRING)]

View File

@ -1,7 +1,7 @@
import abc import abc
from typing import Protocol from typing import Protocol
from core.variables import Variable from core.variables import VariableBase
class ConversationVariableUpdater(Protocol): class ConversationVariableUpdater(Protocol):
@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
""" """
@abc.abstractmethod @abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"): def update(self, conversation_id: str, variable: "VariableBase"):
""" """
Updates the value of the specified conversation variable in the underlying storage. Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value. :param variable: The `VariableBase` instance containing the updated value.
""" """
pass pass

View File

@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion from core.variables.variables import Variable
class CommandType(StrEnum): class CommandType(StrEnum):
@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel): class VariableUpdate(BaseModel):
"""Represents a single variable update instruction.""" """Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value") value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand): class UpdateVariablesCommand(GraphEngineCommand):

View File

@ -11,7 +11,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import ( from core.workflow.enums import (
NodeExecutionType, NodeExecutionType,
@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime, datetime,
list[GraphNodeEventBase], list[GraphNodeEventBase],
object | None, object | None,
dict[str, VariableUnion], dict[str, Variable],
LLMUsage, LLMUsage,
] ]
], ],
@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
item: object, item: object,
flask_app: Flask, flask_app: Flask,
context_vars: contextvars.Context, context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]: ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results.""" """Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None) iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode: match self.node_data.write_mode:

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part # ==================== Validation Part
# Check if variable exists # Check if variable exists
if not isinstance(variable, Variable): if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector) raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported # Check if operation is supported
@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors: for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable): if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector) raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value process_data[variable.name] = variable.value
@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item( def _handle_item(
self, self,
*, *,
variable: Variable, variable: VariableBase,
operation: Operation, operation: Operation,
value: Any, value: Any,
): ):

View File

@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import ( from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID,
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary. # The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one. # elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping", description="Variables mapping",
default=defaultdict(dict), default=defaultdict(dict),
) )
@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables", description="System variables",
default_factory=SystemVariable.empty, default_factory=SystemVariable.empty,
) )
environment_variables: Sequence[VariableUnion] = Field( environment_variables: Sequence[Variable] = Field(
description="Environment variables.", description="Environment variables.",
default_factory=list[VariableUnion], default_factory=list[Variable],
) )
conversation_variables: Sequence[VariableUnion] = Field( conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.", description="Conversation variables.",
default_factory=list[VariableUnion], default_factory=list[Variable],
) )
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.", description="RAG pipeline variables.",
@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements" f"got {len(selector)} elements"
) )
if isinstance(value, Variable): if isinstance(value, VariableBase):
variable = value variable = value
elif isinstance(value, Segment): elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector) variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector) variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector) node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`, # Based on the definition of `Variable`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod @classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:

View File

@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Protocol from typing import Any, Protocol
from core.variables import Variable from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
@ -26,7 +26,7 @@ class VariableLoader(Protocol):
""" """
@abc.abstractmethod @abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]: def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty, """Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list. this method should return an empty list.
@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements: :param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID, - the first element is the node ID,
- the second element is the variable name. - the second element is the variable name.
:return: a list of Variable objects that match the provided selectors. :return: a list of VariableBase objects that match the provided selectors.
""" """
pass pass
@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed. Serves as a placeholder when no variable loading is needed.
""" """
def load_variables(self, selectors: list[list[str]]) -> list[Variable]: def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return [] return []

View File

@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
@ -136,13 +137,11 @@ class WorkflowEntry:
:param user_inputs: user inputs :param user_inputs: user inputs
:return: :return:
""" """
node_config = workflow.get_node_config_by_id(node_id) node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {}) node_config_data = node_config.get("data", {})
# Get node class # Get node type
node_type = NodeType(node_config_data.get("type")) node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state # init graph init params and runtime state
graph_init_params = GraphInitParams( graph_init_params = GraphInitParams(
@ -158,12 +157,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state # init workflow run state
node = node_cls( node_factory = DifyNodeFactory(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
) )
node = node_factory.create_node(node_config)
node_cls = type(node)
try: try:
# variable selector to variable mapping # variable selector to variable mapping

View File

@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv from dotenv import load_dotenv
from configs import dify_config
from dify_app import DifyApp from dify_app import DifyApp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,12 +20,17 @@ def is_enabled() -> bool:
""" """
Check if logstore extension is enabled. Check if logstore extension is enabled.
Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation
Returns: Returns:
True if all required Aliyun SLS environment variables are set, False otherwise True if logstore should be initialized, False otherwise
""" """
# Load environment variables from .env file # Load environment variables from .env file
load_dotenv() load_dotenv()
# Check if Aliyun SLS connection parameters are configured
required_vars = [ required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID", "ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET", "ALIYUN_SLS_ACCESS_KEY_SECRET",
@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME", "ALIYUN_SLS_PROJECT_NAME",
] ]
all_set = all(os.environ.get(var) for var in required_vars) sls_vars_set = all(os.environ.get(var) for var in required_vars)
if not all_set: if not sls_vars_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set") return False
return all_set # Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
if not uses_logstore:
return False
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
def init_app(app: DifyApp): def init_app(app: DifyApp):
""" """
Initialize logstore on application startup. Initialize logstore on application startup.
If initialization fails, the application continues running without logstore features.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
Args: Args:
app: The Dify application instance app: The Dify application instance
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try: try:
from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...") logger.info("Initializing Aliyun SLS Logstore...")
# Create logstore client and initialize project/logstores/indexes # Create logstore client and initialize resources
logstore_client = AliyunLogStore() logstore_client = AliyunLogStore()
logstore_client.init_project_logstore() logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully") logger.info("Logstore initialized successfully")
except Exception: except Exception:
logger.exception("Failed to initialize logstore") logger.exception(
# Don't raise - allow application to continue even if logstore init fails "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
# This ensures that the application can still run if logstore is misconfigured "Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging import logging
import os import os
import socket
import threading import threading
import time import time
from collections.abc import Sequence from collections.abc import Sequence
@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "") self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "") self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365)) self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true" self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)
# Initialize SDK client # Initialize SDK client
self.client = LogClient( self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True self.__class__._initialized = True
@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.
Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds
Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception: # noqa: S110
pass # Ignore errors during cleanup
@property @property
def supports_pg_protocol(self) -> bool: def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled.""" """Check if PG protocol is supported and enabled."""
@ -220,19 +273,16 @@ class AliyunLogStore:
try: try:
self._use_pg_protocol = self._pg_client.init_connection() self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol: if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name) logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores # Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled() self._check_and_disable_pg_if_scan_index_disabled()
return True return True
else: else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name) logger.info("Using SDK mode for project %s", self.project_name)
return False return False
except Exception as e: except Exception as e:
logger.warning( logger.info("Using SDK mode for project %s", self.project_name)
"Failed to establish PG connection for project %s: %s. Will use SDK mode.", logger.debug("PG connection details: %s", str(e))
self.project_name,
str(e),
)
self._use_pg_protocol = False self._use_pg_protocol = False
return False return False
@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol: if self._use_pg_protocol:
return return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init() self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None self.__class__._pg_connection_timer = None
@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new: if project_is_new:
# For newly created projects, schedule delayed PG connection # For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False self._use_pg_protocol = False
logger.info( logger.info("Using SDK mode for project %s (newly created)", self.project_name)
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
if self.__class__._pg_connection_timer is not None: if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel() self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer( self.__class__._pg_connection_timer = threading.Timer(
@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start() self.__class__._pg_connection_timer.start()
else: else:
# For existing projects, attempt PG connection immediately # For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init() self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None: def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name) existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index: if existing_config and not existing_config.scan_index:
logger.info( logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. " "Logstore %s requires scan_index enabled, using SDK mode for project %s",
"PG protocol requires scan_index to be enabled.",
logstore_name, logstore_name,
self.project_name,
) )
self._use_pg_protocol = False self._use_pg_protocol = False
# Close PG connection if it was initialized # Close PG connection if it was initialized
@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse, reverse=reverse,
) )
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled: if self.log_enabled:
logger.info( logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | " "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs: for log in logs:
result.append(log.get_contents()) result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled: if self.log_enabled:
logger.info( logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d", "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query, query=full_query,
) )
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled: if self.log_enabled:
logger.info( logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s", "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name, self.project_name,
from_time, from_time,
to_time, to_time,
query, full_query,
sql,
) )
try: try:
@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs: for log in logs:
result.append(log.get_contents()) result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled: if self.log_enabled:
logger.info( logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",

View File

@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any from typing import Any
import psycopg2 import psycopg2
import psycopg2.pool from sqlalchemy import create_engine
from psycopg2 import InterfaceError, OperationalError
from configs import dify_config from configs import dify_config
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG: class AliyunLogStorePG:
""" """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str): def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
""" """
@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret self._access_key_secret = access_key_secret
self._endpoint = endpoint self._endpoint = endpoint
self.project_name = project_name self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool: def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
""" """Fast TCP port check to avoid long waits on unsupported regions."""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout) sock.settimeout(timeout)
@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False return False
def init_connection(self) -> bool: def init_connection(self) -> bool:
""" """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
try: try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "") pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration # Pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10)) pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
logger.debug( logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
# Fast port connectivity check before attempting full connection # Fast port check to avoid long waits
# This prevents long waits when connecting to unsupported regions
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0): if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info( logger.debug("Using SDK mode for host=%s", pg_host)
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
return False return False
# Create connection pool # Build connection URL
self._pg_pool = psycopg2.pool.SimpleConnectionPool( from urllib.parse import quote_plus
minconn=1,
maxconn=pg_max_connections, username = quote_plus(self._access_key_id)
host=pg_host, password = quote_plus(self._access_key_secret)
port=5432, database_url = (
database=self.project_name, f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
) )
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables # Create SQLAlchemy engine with connection pool
# Connection pool creation success already indicates connectivity self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)
self._use_pg_protocol = True self._use_pg_protocol = True
logger.info( logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.", "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name, self.project_name,
pool_size,
pool_recycle,
) )
return True return True
except Exception as e: except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False self._use_pg_protocol = False
if self._pg_pool: if self._engine:
try: try:
self._pg_pool.closeall() self._engine.dispose()
except Exception: except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring") logger.debug("Failed to dispose engine during cleanup, ignoring")
self._pg_pool = None self._engine = None
logger.info( logger.debug("Using SDK mode for region: %s", str(e))
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
return False return False
@contextmanager @contextmanager
def _get_connection(self): def _get_connection(self):
""" """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
Context manager to get a PostgreSQL connection from the pool. if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")
Automatically validates and refreshes stale connections. connection = self._engine.raw_connection()
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
try: try:
# Validate connection and get a fresh one if needed connection.autocommit = True # SLS PG protocol does not support transactions
if not self._is_connection_valid(conn): yield connection
logger.debug("Connection is stale, marking as bad and getting a new one") except Exception:
# Mark connection as bad and get a new one raise
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
finally: finally:
# Return connection to pool (or close if it's bad) connection.close()
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
def close(self) -> None: def close(self) -> None:
"""Close the PostgreSQL connection pool.""" """Dispose SQLAlchemy engine and close all connections."""
if self._pg_pool: if self._engine:
try: try:
self._pg_pool.closeall() self._engine.dispose()
logger.info("PG connection pool closed") logger.info("SQLAlchemy engine disposed")
except Exception: except Exception:
logger.exception("Failed to close PG connection pool") logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool: def _is_retriable_error(self, error: Exception) -> bool:
""" """Check if error is retriable (connection-related issues)."""
Check if an error is retriable (connection-related issues). # Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
return True return True
# Check error message for specific connection issues
error_msg = str(error).lower() error_msg = str(error).lower()
retriable_patterns = [ retriable_patterns = [
"connection", "connection",
@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer", "reset by peer",
"no route to host", "no route to host",
"network", "network",
"operational error",
"interface error",
] ]
return any(pattern in error_msg for pattern in retriable_patterns) return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None: def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
""" """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if not contents: if not contents:
return return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents] fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents] values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields]) field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled: if log_enabled:
@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents), len(contents),
) )
# Retry configuration
max_retries = 3 max_retries = 3
retry_delay = 0.1 # Start with 100ms retry_delay = 0.1
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields)) placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql) cursor.execute(insert_sql)
# Success - exit retry loop
return return
except psycopg2.Error as e: except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e): if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
raise raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1: if attempt < max_retries - 1:
logger.warning( logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", "Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore, logstore,
attempt + 1, attempt + 1,
max_retries, max_retries,
str(e), str(e),
) )
time.sleep(retry_delay) time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff retry_delay *= 2
else: else:
# Last attempt failed logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
raise raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
""" """Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if log_enabled: if log_enabled:
logger.info( logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql, sql,
) )
# Retry configuration
max_retries = 3 max_retries = 3
retry_delay = 0.1 # Start with 100ms retry_delay = 0.1
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute(sql) cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description] columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = [] result = []
for row in cursor.fetchall(): for row in cursor.fetchall():
row_dict = {} row_dict = {}
@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result return result
except psycopg2.Error as e: except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e): if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception( logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", "Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore, logstore,
sql, sql,
) )
raise raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1: if attempt < max_retries - 1:
logger.warning( logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore, logstore,
attempt + 1, attempt + 1,
max_retries, max_retries,
str(e), str(e),
) )
time.sleep(retry_delay) time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff retry_delay *= 2
else: else:
# Last attempt failed
logger.exception( logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", "Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore, logstore,
max_retries, max_retries,
sql, sql,
) )
raise raise
# This line should never be reached due to raise above, but makes type checker happy
return [] return []

View File

@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""
from typing import Any
def safe_float(value: Any, default: float = 0.0) -> float:
"""
Safely convert a value to float, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value: Any, default: int = 0) -> int:
"""
Safely convert a value to int, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default

View File

@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or "" model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or "" model.created_by = data.get("created_by") or ""
# Numeric fields with defaults model.index = safe_int(data.get("index", 0))
model.index = int(data.get("index", 0)) model.elapsed_time = safe_float(data.get("elapsed_time", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
# Optional fields # Optional fields
model.workflow_run_id = data.get("workflow_run_id") model.workflow_run_id = data.get("workflow_run_id")
@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id, node_id,
) )
try: try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)
# Check if PG protocol is supported # Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol: if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record) # Use PG protocol with SQL query (get latest version of each record)
@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *, SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}" FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}' WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{app_id}' AND app_id = '{escaped_app_id}'
AND workflow_id = '{workflow_id}' AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{node_id}' AND node_id = '{escaped_node_id}'
AND __time__ > 0 AND __time__ > 0
) AS subquery WHERE rn = 1 ) AS subquery WHERE rn = 1
LIMIT 100 LIMIT 100
@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else: else:
# Use SDK with LogStore query syntax # Use SDK with LogStore query syntax
query = ( query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
) )
from_time = 0 from_time = 0
to_time = int(time.time()) # now to_time = int(time.time()) # now
@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id, workflow_run_id,
) )
try: try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)
# Check if PG protocol is supported # Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol: if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record) # Use PG protocol with SQL query (get latest version of each record)
@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *, SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}" FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}' WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{app_id}' AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{workflow_run_id}' AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0 AND __time__ > 0
) AS subquery WHERE rn = 1 ) AS subquery WHERE rn = 1
LIMIT 1000 LIMIT 1000
@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
) )
else: else:
# Use SDK with LogStore query syntax # Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0 from_time = 0
to_time = int(time.time()) # now to_time = int(time.time()) # now
@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
""" """
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try: try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)
# Check if PG protocol is supported # Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol: if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record) # Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""
sql_query = f""" sql_query = f"""
SELECT * FROM ( SELECT * FROM (
SELECT *, SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}" FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1 ) AS subquery WHERE rn = 1
LIMIT 1 LIMIT 1
""" """
@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
) )
else: else:
# Use SDK with LogStore query syntax # Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id: if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}" query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else: else:
query = f"id: {execution_id}" query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0 from_time = 0
to_time = int(time.time()) # now to_time = int(time.time()) # now

View File

@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter - Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries) - Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security - Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
""" """
import logging import logging
@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or "" model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or "" model.created_by = data.get("created_by") or ""
# Numeric fields with defaults model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_tokens = int(data.get("total_tokens", 0)) model.total_steps = safe_int(data.get("total_steps", 0))
model.total_steps = int(data.get("total_steps", 0)) model.exceptions_count = safe_int(data.get("exceptions_count", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
# Optional fields # Optional fields
model.graph = data.get("graph") model.graph = data.get("graph")
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at: if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds() model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else: else:
model.elapsed_time = float(data.get("elapsed_time", 0)) # Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model return model
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status, status,
) )
# Convert triggered_from to list if needed # Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom): if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from] triggered_from_list = [triggered_from]
else: else:
triggered_from_list = list(triggered_from) triggered_from_list = list(triggered_from)
# Build triggered_from filter # Escape parameters to prevent SQL injection
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Build status filter # Build triggered_from filter with escaped values
status_filter = f"AND status='{status}'" if status else "" # Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)
# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination # Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record # Note: This is simplified. In production, you'd need to track created_at from last record
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM ( SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND ({triggered_from_filter}) AND ({triggered_from_filter})
{status_filter} {status_filter}
{last_id_filter} {last_id_filter}
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try: try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Check if PG protocol is supported # Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol: if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record) # Use PG protocol with SQL query (get latest version of record)
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *, SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}" FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1 ) AS subquery WHERE rn = 1
LIMIT 100 LIMIT 100
""" """
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
) )
else: else:
# Use SDK with LogStore query syntax # Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" # Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0 from_time = 0
to_time = int(time.time()) # now to_time = int(time.time()) # now
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try: try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
# Check if PG protocol is supported # Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol: if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record) # Use PG protocol with SQL query (get latest version of record)
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *, SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}" FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0 WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1 ) AS subquery WHERE rn = 1
LIMIT 100 LIMIT 100
""" """
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
) )
else: else:
# Use SDK with LogStore query syntax # Use SDK with LogStore query syntax
query = f"id: {run_id}" # Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0 from_time = 0
to_time = int(time.time()) # now to_time = int(time.time()) # now
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from, triggered_from,
status, status,
) )
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter # Build time range filter
time_filter = "" time_filter = ""
if time_range: if time_range:
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count # If status is provided, simple count
if status: if status:
escaped_status = escape_sql_string(status)
if status == "running": if status == "running":
# Running status requires window function # Running status requires window function
sql = f""" sql = f"""
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM ( FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND status='running' AND status='running'
{time_filter} {time_filter}
) t ) t
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f""" sql = f"""
SELECT COUNT(DISTINCT id) as count SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND status='{status}' AND status='{escaped_status}'
AND finished_at IS NOT NULL AND finished_at IS NOT NULL
{time_filter} {time_filter}
""" """
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status # No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running # Use optimized query for finished runs, separate query for running
try: try:
# Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status # Count finished runs grouped by status
finished_sql = f""" finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL AND finished_at IS NOT NULL
{time_filter} {time_filter}
GROUP BY status GROUP BY status
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM ( FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND status='running' AND status='running'
{time_filter} {time_filter}
) t ) t
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug( logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
) )
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = "" time_filter = ""
if start_date: if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f""" sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL AND finished_at IS NOT NULL
{time_filter} {time_filter}
GROUP BY date GROUP BY date
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id, app_id,
triggered_from, triggered_from,
) )
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = "" time_filter = ""
if start_date: if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f""" sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL AND finished_at IS NOT NULL
{time_filter} {time_filter}
GROUP BY date GROUP BY date
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id, app_id,
triggered_from, triggered_from,
) )
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = "" time_filter = ""
if start_date: if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f""" sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL AND finished_at IS NOT NULL
{time_filter} {time_filter}
GROUP BY date GROUP BY date
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id, app_id,
triggered_from, triggered_from,
) )
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = "" time_filter = ""
if start_date: if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by, created_by,
COUNT(DISTINCT id) AS interactions COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore} FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}' WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{app_id}' AND app_id='{escaped_app_id}'
AND triggered_from='{triggered_from}' AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL AND finished_at IS NOT NULL
{time_filter} {time_filter}
GROUP BY date, created_by GROUP BY date, created_by

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id from libs.helper import extract_tenant_id
from models import ( from models import (
@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__( def __init__(
self, self,
@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database) # Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only # Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore. # Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning # Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns()) log_version = str(time.time_ns())
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [ logstore_model = [
("id", domain_model.id_), ("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes ("log_version", log_version), # Add log_version field for append-only writes
@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version), ("version", domain_model.workflow_version),
( (
"graph", "graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field if domain_model.graph and self._enable_put_graph_field
else "{}", else "{}",
), ),
( (
"inputs", "inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable) json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs if domain_model.inputs
else "{}", else "{}",
), ),
( (
"outputs", "outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable) json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs if domain_model.outputs
else "{}", else "{}",
), ),

View File

@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id from libs.helper import extract_tenant_id
from models import ( from models import (
Account, Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"), node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""), workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"), workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)), index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"), predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""), node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")), node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs, outputs=outputs,
status=status, status=status,
error=data.get("error"), error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)), elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata, metadata=domain_metadata,
created_at=created_at, created_at=created_at,
finished_at=finished_at, finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database) # Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only # Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]: def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug( logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific Save or update the inputs, process_data, or outputs associated with a specific
node_execution record. node_execution record.
For LogStore implementation, this is similar to save() since we always write For LogStore implementation, this is a no-op for the LogStore write because save()
complete records. We append a new record with updated data fields. already writes all fields including inputs, process_data, and outputs. The caller
typically calls save() first to persist status/metadata, then calls save_execution_data()
to persist data fields. Since LogStore writes complete records atomically, we don't
need a separate write here to avoid duplicate records.
However, if dual-write is enabled, we still need to call the SQL repository's
save_execution_data() method to properly update the SQL database.
Args: Args:
execution: The NodeExecution instance with data to save execution: The NodeExecution instance with data to save
""" """
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) logger.debug(
# In LogStore, we simply write a new complete record with the data "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
# The log_version timestamp will ensure this is treated as the latest version execution.id,
self.save(execution) execution.node_execution_id,
)
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
# Calling save() again would create a duplicate record in the append-only LogStore
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save_execution_data(execution)
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run( def get_by_workflow_run(
self, self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]: ) -> Sequence[WorkflowNodeExecution]:
""" """
Retrieve all NodeExecution instances for a specific workflow run. Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. Uses LogStore SQL query with window function to get the latest version of each node execution.
This ensures we only get the final version of each node execution. This ensures we only get the most recent version of each node execution record.
Args: Args:
workflow_run_id: The workflow run ID workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances A list of NodeExecution instances
Note: Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from This method uses ROW_NUMBER() window function partitioned by node_execution_id
version updates. For complete history including intermediate states, to get the latest version (highest log_version) of each node execution.
a different query strategy would be needed.
""" """
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL # Build SQL query with deduplication using window function
# This optimization avoids window functions for common case where we only # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
# want the final state of each node execution # ensures we get the latest version of each node execution
# Build ORDER BY clause # Escape parameters to prevent SQL injection
escaped_workflow_run_id = escape_identifier(workflow_run_id)
escaped_tenant_id = escape_identifier(self._tenant_id)
# Build ORDER BY clause for outer query
order_clause = "" order_clause = ""
if order_config and order_config.order_by: if order_config and order_config.order_by:
order_fields = [] order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields: if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields) order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f""" # Build app_id filter for subquery
SELECT * app_id_filter = ""
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
if self._app_id: if self._app_id:
sql += f" AND app_id='{self._app_id}'" escaped_app_id = escape_identifier(self._app_id)
app_id_filter = f" AND app_id='{escaped_app_id}'"
# Use window function to get latest version of each node execution
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{escaped_workflow_run_id}'
AND tenant_id='{escaped_tenant_id}'
{app_id_filter}
) t
WHERE rn = 1
"""
if order_clause: if order_clause:
sql += f" {order_clause}" sql += f" {order_clause}"

View File

@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries
This module provides escaping utilities to prevent injection attacks in LogStore queries.
LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""
def escape_sql_string(value: str) -> str:
"""
Escape a string value for safe use in SQL queries.
This function escapes single quotes by doubling them, which is the standard
SQL escaping method. This prevents SQL injection by ensuring that user input
cannot break out of string literals.
Args:
value: The string value to escape
Returns:
Escaped string safe for use in SQL queries
Examples:
>>> escape_sql_string("normal_value")
"normal_value"
>>> escape_sql_string("value' OR '1'='1")
"value'' OR ''1''=''1"
>>> escape_sql_string("tenant's_id")
"tenant''s_id"
Security:
- Prevents breaking out of string literals
- Stops injection attacks like: ' OR '1'='1
- Protects against cross-tenant data access
"""
if not value:
return value
# Escape single quotes by doubling them (standard SQL escaping)
# This prevents breaking out of string literals in SQL queries
return value.replace("'", "''")
def escape_identifier(value: str) -> str:
"""
Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
This function is for PG protocol mode (SQL syntax).
For SDK mode, use escape_logstore_query_value() instead.
Args:
value: The identifier value to escape
Returns:
Escaped identifier safe for use in SQL queries
Examples:
>>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
"550e8400-e29b-41d4-a716-446655440000"
>>> escape_identifier("tenant_id' OR '1'='1")
"tenant_id'' OR ''1''=''1"
Security:
- Prevents SQL injection via identifiers
- Stops cross-tenant access attempts
- Works for UUIDs, alphanumeric IDs, and similar identifiers
"""
# For identifiers, use the same escaping as strings
# This is simple and effective for preventing injection
return escape_sql_string(value)
def escape_logstore_query_value(value: str) -> str:
"""
Escape value for LogStore query syntax (SDK mode).
LogStore query syntax rules:
1. Keywords (and/or/not) are case-insensitive
2. Single quotes are ordinary characters (no special meaning)
3. Double quotes wrap values: key:"value"
4. Backslash is the escape character:
- \" for double quote inside value
- \\ for backslash itself
5. Parentheses can change query structure
To prevent injection:
- Wrap value in double quotes to treat special chars as literals
- Escape backslashes and double quotes using backslash
Args:
value: The value to escape for LogStore query syntax
Returns:
Quoted and escaped value safe for LogStore query syntax (includes the quotes)
Examples:
>>> escape_logstore_query_value("normal_value")
'"normal_value"'
>>> escape_logstore_query_value("value or field:evil")
'"value or field:evil"' # 'or' and ':' are now literals
>>> escape_logstore_query_value('value"test')
'"value\\"test"' # Internal double quote escaped
>>> escape_logstore_query_value('value\\test')
'"value\\\\test"' # Backslash escaped
Security:
- Prevents injection via and/or/not keywords
- Prevents injection via colons (:)
- Prevents injection via parentheses
- Protects against cross-tenant data access
Note:
Escape order is critical: backslash first, then double quotes.
Otherwise, we'd double-escape the escape character itself.
"""
if not value:
return '""'
# IMPORTANT: Escape backslashes FIRST, then double quotes
# This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
escaped = value.replace("\\", "\\\\") # \ -> \\
escaped = escaped.replace('"', '\\"') # " -> \"
# Wrap in double quotes to treat as literal string
# This prevents and/or/not/:/() from being interpreted as operators
return f'"{escaped}"'

View File

@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable, ObjectVariable,
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, VariableBase,
) )
from core.workflow.constants import ( from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
} }
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"): if not mapping.get("name"):
raise VariableError("missing name") raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"): if not mapping.get("name"):
raise VariableError("missing name") raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"): if not mapping.get("variable"):
raise VariableError("missing variable") raise VariableError("missing variable")
return mapping["variable"] return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
""" """
This factory function is used to create the environment variable or the conversation variable, This factory function is used to create the environment variable or the conversation variable,
not support the File type. not support the File type.
@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None: if (value := mapping.get("value")) is None:
raise VariableError("missing value") raise VariableError("missing value")
result: Variable result: VariableBase
match value_type: match value_type:
case SegmentType.STRING: case SegmentType.STRING:
result = StringVariable.model_validate(mapping) result = StringVariable.model_validate(mapping)
@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector: if not result.selector:
result = result.model_copy(update={"selector": selector}) result = result.model_copy(update={"selector": selector})
return cast(Variable, result) return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment: def build_segment(value: Any, /) -> Segment:
@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None, id: str | None = None,
name: str | None = None, name: str | None = None,
description: str = "", description: str = "",
) -> Variable: ) -> VariableBase:
if isinstance(segment, Variable): if isinstance(segment, VariableBase):
return segment return segment
name = name or selector[-1] name = name or selector[-1]
id = id or str(uuid4()) id = id or str(uuid4())
@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast( return cast(
Variable, VariableBase,
variable_class( variable_class(
id=id, id=id,
name=name, name=name,

View File

@ -1,7 +1,7 @@
from flask_restx import fields from flask_restx import fields
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
from libs.helper import TimestampField from libs.helper import TimestampField
@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value, "value_type": value.value_type.value,
"description": value.description, "description": value.description,
} }
if isinstance(value, Variable): if isinstance(value, VariableBase):
return { return {
"id": value.id, "id": value.id,
"name": value.name, "name": value.name,

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import json import json
import logging import logging
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline" RAG_PIPELINE = "rag-pipeline"
@classmethod @classmethod
def value_of(cls, value: str) -> WorkflowType: def value_of(cls, value: str) -> "WorkflowType":
""" """
Get value of given mode. Get value of given mode.
@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}") raise ValueError(f"invalid workflow type value {value}")
@classmethod @classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
""" """
Get workflow type from app mode. Get workflow type from app mode.
@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str, graph: str,
features: str, features: str,
created_by: str, created_by: str,
environment_variables: Sequence[Variable], environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict], rag_pipeline_variables: list[dict],
marked_name: str = "", marked_name: str = "",
marked_comment: str = "", marked_comment: str = "",
) -> Workflow: ) -> "Workflow":
workflow = Workflow() workflow = Workflow()
workflow.id = str(uuid4()) workflow.id = str(uuid4())
workflow.tenant_id = tenant_id workflow.tenant_id = tenant_id
@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value # decrypt secret variables value
def decrypt_func( def decrypt_func(
var: Variable, var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results return decrypted_results
@environment_variables.setter @environment_variables.setter
def environment_variables(self, value: Sequence[Variable]): def environment_variables(self, value: Sequence[VariableBase]):
if not value: if not value:
self._environment_variables = "{}" self._environment_variables = "{}"
return return
@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value # encrypt secret variables value
def encrypt_func(var: Variable) -> Variable: def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else: else:
@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result return result
@property @property
def conversation_variables(self) -> Sequence[Variable]: def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created. # TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None: if self._conversation_variables is None:
self._conversation_variables = "{}" self._conversation_variables = "{}"
@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results return results
@conversation_variables.setter @conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]): def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps( self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value}, {var.name: var.model_dump() for var in value},
ensure_ascii=False, ensure_ascii=False,
@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime) finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship( pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause", "WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False, uselist=False,
@ -692,7 +690,7 @@ class WorkflowRun(Base):
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls( return cls(
id=data.get("id"), id=data.get("id"),
tenant_id=data.get("tenant_id"), tenant_id=data.get("tenant_id"),
@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID) created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime) finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload", "WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True, uselist=True,
@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod @staticmethod
def preload_offload_data( def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
): ):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod @staticmethod
def preload_offload_data_and_files( def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
): ):
return query.options( return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
) )
return extras return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None) return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property @property
@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data", back_populates="offload_data",
) )
file: Mapped[UploadFile | None] = orm.relationship( file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id], foreign_keys=[file_id],
lazy="raise", lazy="raise",
uselist=False, uselist=False,
@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app" INSTALLED_APP = "installed-app"
@classmethod @classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
""" """
Get value of given mode. Get value of given mode.
@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
) )
@classmethod @classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls( obj = cls(
id=variable.id, id=variable.id,
app_id=app_id, app_id=app_id,
@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
) )
return obj return obj
def to_variable(self) -> Variable: def to_variable(self) -> VariableBase:
mapping = json.loads(self.data) mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping) return variable_factory.build_conversation_variable_from_mapping(mapping)
@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
) )
# Relationship to WorkflowDraftVariableFile # Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id], foreign_keys=[file_id],
lazy="raise", lazy="raise",
uselist=False, uselist=False,
@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None, node_execution_id: str | None,
description: str = "", description: str = "",
file_id: str | None = None, file_id: str | None = None,
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable() variable = WorkflowDraftVariable()
variable.id = str(uuid4()) variable.id = str(uuid4())
variable.created_at = naive_utc_now() variable.created_at = naive_utc_now()
@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
name: str, name: str,
value: Segment, value: Segment,
description: str = "", description: str = "",
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = cls._new( variable = cls._new(
app_id=app_id, app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID, node_id=CONVERSATION_VARIABLE_NODE_ID,
@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
value: Segment, value: Segment,
node_execution_id: str, node_execution_id: str,
editable: bool = False, editable: bool = False,
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = cls._new( variable = cls._new(
app_id=app_id, app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID, node_id=SYSTEM_VARIABLE_NODE_ID,
@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True, visible: bool = True,
editable: bool = True, editable: bool = True,
file_id: str | None = None, file_id: str | None = None,
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = cls._new( variable = cls._new(
app_id=app_id, app_id=app_id,
node_id=node_id, node_id=node_id,
@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
) )
# Relationship to UploadFile # Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship( upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id], foreign_keys=[upload_file_id],
lazy="raise", lazy="raise",
uselist=False, uselist=False,
@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun # Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship( workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id], foreign_keys=[workflow_run_id],
# require explicit preloading. # require explicit preloading.
lazy="raise", lazy="raise",
@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
) )
@classmethod @classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired): if isinstance(pause_reason, HumanInputRequired):
return cls( return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dify-api" name = "dify-api"
version = "1.11.2" version = "1.11.3"
requires-python = ">=3.11,<3.13" requires-python = ">=3.11,<3.13"
dependencies = [ dependencies = [

View File

@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast from typing import Any, cast
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email) cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod @classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login") return TokenManager.get_token_data(token, "email_code_login")
@ -1363,16 +1378,22 @@ class RegisterService:
if not inviter: if not inviter:
raise ValueError("Inviter is required") raise ValueError("Inviter is required")
normalized_email = email.lower()
"""Invite new member""" """Invite new member"""
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first() account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account: if not account:
TenantService.check_member_permission(tenant, inviter, None, "add") TenantService.check_member_permission(tenant, inviter, None, "add")
name = email.split("@")[0] name = normalized_email.split("@")[0]
account = cls.register( account = cls.register(
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True email=normalized_email,
name=name,
language=language,
status=AccountStatus.PENDING,
is_setup=True,
) )
# Create new tenant member for invited tenant # Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)
@ -1394,7 +1415,7 @@ class RegisterService:
# send email # send email
send_invite_member_mail_task.delay( send_invite_member_mail_task.delay(
language=language, language=language,
to=email, to=account.email,
token=token, token=token,
inviter_name=inviter.name if inviter else "Dify", inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name, workspace_name=tenant.name,
@ -1493,6 +1514,16 @@ class RegisterService:
invitation: dict = json.loads(data) invitation: dict = json.loads(data)
return invitation return invitation
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
def _generate_refresh_token(length: int = 64): def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length) token = secrets.token_hex(length)

View File

@ -1,7 +1,7 @@
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable from core.variables.variables import VariableBase
from models import ConversationVariable from models import ConversationVariable
@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None: def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None: def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where( stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
) )

View File

@ -1,9 +1,14 @@
import logging
import os import os
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
import httpx import httpx
from core.helper.trace_id_helper import generate_traceparent_header
logger = logging.getLogger(__name__)
class BaseRequest: class BaseRequest:
proxies: Mapping[str, str] | None = { proxies: Mapping[str, str] | None = {
@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts() mounts = cls._build_mounts()
try:
# ensure traceparent even when OTEL is disabled
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client: with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers) response = client.request(method, url, json=json, params=params, headers=headers)
return response.json() return response.json()

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type from mimetypes import guess_type
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select
from yarl import URL from yarl import URL
from configs import dify_config from configs import dify_config
@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope from services.feature_service import FeatureService, PluginInstallationScope
@ -506,6 +509,33 @@ class PluginService:
@staticmethod @staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller() manager = PluginInstaller()
# Get plugin info before uninstalling to delete associated credentials
try:
plugins = manager.list_plugins(tenant_id)
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if plugin:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
# Delete provider credentials that match this plugin
credentials = db.session.scalars(
select(ProviderCredential).where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.like(f"{plugin_id}/%"),
)
).all()
for cred in credentials:
db.session.delete(cred)
db.session.commit()
logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
except Exception as e:
logger.warning("Failed to delete credentials: %s", e)
# Continue with uninstall even if credential deletion fails
return manager.uninstall(tenant_id, plugin_installation_id) return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod @staticmethod

View File

@ -36,7 +36,7 @@ from core.rag.entities.event import (
) )
from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import ( from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution, WorkflowNodeExecution,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict, graph: dict,
unique_hash: str | None, unique_hash: str | None,
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list, rag_pipeline_variables: list,
) -> Workflow: ) -> Workflow:
""" """

View File

@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password from libs.password import compare_password
from models import Account, AccountStatus from models import Account, AccountStatus
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.account_service import AccountService
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod @staticmethod
def authenticate(email: str, password: str) -> Account: def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first() account = AccountService.get_account_by_email_with_case_fallback(email)
if not account: if not account:
raise AccountNotFoundError() raise AccountNotFoundError()
@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod @classmethod
def get_user_through_email(cls, email: str): def get_user_through_email(cls, email: str):
account = db.session.query(Account).where(Account.email == email).first() account = AccountService.get_account_by_email_with_case_fallback(email)
if not account: if not account:
return None return None

View File

@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.variables import Segment, StringSegment, Variable from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import ( from core.variables.segments import (
ArrayFileSegment, ArrayFileSegment,
@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded. # Application ID for which variables are being loaded.
_app_id: str _app_id: str
_tenant_id: str _tenant_id: str
_fallback_variables: Sequence[Variable] _fallback_variables: Sequence[VariableBase]
def __init__( def __init__(
self, self,
engine: Engine, engine: Engine,
app_id: str, app_id: str,
tenant_id: str, tenant_id: str,
fallback_variables: Sequence[Variable] | None = None, fallback_variables: Sequence[VariableBase] | None = None,
): ):
self._engine = engine self._engine = engine
self._app_id = app_id self._app_id = app_id
@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1]) return (selector[0], selector[1])
def load_variables(self, selectors: list[list[str]]) -> list[Variable]: def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors: if not selectors:
return [] return []
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {} variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session: with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session) srv = WorkflowDraftVariableService(session)
@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values()) return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]: def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable` # This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it. # and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability. # Ideally, these should be co-located for better maintainability.

View File

@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File from core.file import File
from core.repositories import DifyCoreRepositoryFactory from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable from core.variables import VariableBase
from core.variables.variables import VariableUnion from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
@ -198,8 +198,8 @@ class WorkflowService:
features: dict, features: dict,
unique_hash: str | None, unique_hash: str | None,
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[VariableBase],
) -> Workflow: ) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow, workflow: Workflow,
node_type: NodeType, node_type: NodeType,
conversation_id: str, conversation_id: str,
conversation_variables: list[Variable], conversation_variables: list[VariableBase],
): ):
# Only inject system variables for START node type. # Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node: if node_type == NodeType.START or node_type.is_trigger_node:
@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable, system_variables=system_variable,
user_inputs=user_inputs, user_inputs=user_inputs,
environment_variables=workflow.environment_variables, environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `Variable`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), # conversation_variables=cast(list[Variable], conversation_variables), #
) )
return variable_pool return variable_pool

View File

@ -40,7 +40,7 @@ class TestActivateCheckApi:
"tenant": tenant, "tenant": tenant,
} }
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation): def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
""" """
Test checking valid invitation token. Test checking valid invitation token.
@ -66,7 +66,7 @@ class TestActivateCheckApi:
assert response["data"]["workspace_id"] == "workspace-123" assert response["data"]["workspace_id"] == "workspace-123"
assert response["data"]["email"] == "invitee@example.com" assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_invalid_invitation_token(self, mock_get_invitation, app): def test_check_invalid_invitation_token(self, mock_get_invitation, app):
""" """
Test checking invalid invitation token. Test checking invalid invitation token.
@ -88,7 +88,7 @@ class TestActivateCheckApi:
# Assert # Assert
assert response["is_valid"] is False assert response["is_valid"] is False
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation): def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
""" """
Test checking token without workspace ID. Test checking token without workspace ID.
@ -109,7 +109,7 @@ class TestActivateCheckApi:
assert response["is_valid"] is True assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token") mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation): def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
""" """
Test checking token without email parameter. Test checking token without email parameter.
@ -130,6 +130,20 @@ class TestActivateCheckApi:
assert response["is_valid"] is True assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token") mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
"""Ensure token validation uses lowercase emails."""
mock_get_invitation.return_value = mock_invitation
with app.test_request_context(
"/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token"
):
api = ActivateCheckApi()
response = api.get()
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
class TestActivateApi: class TestActivateApi:
"""Test cases for account activation endpoint.""" """Test cases for account activation endpoint."""
@ -212,7 +226,7 @@ class TestActivateApi:
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
mock_db.session.commit.assert_called_once() mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_activation_with_invalid_token(self, mock_get_invitation, app): def test_activation_with_invalid_token(self, mock_get_invitation, app):
""" """
Test account activation with invalid token. Test account activation with invalid token.
@ -241,7 +255,7 @@ class TestActivateApi:
with pytest.raises(AlreadyActivateError): with pytest.raises(AlreadyActivateError):
api.post() api.post()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.db")
def test_activation_sets_interface_theme( def test_activation_sets_interface_theme(
@ -290,7 +304,7 @@ class TestActivateApi:
("es-ES", "Europe/Madrid"), ("es-ES", "Europe/Madrid"),
], ],
) )
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.db")
def test_activation_with_different_locales( def test_activation_with_different_locales(
@ -336,7 +350,7 @@ class TestActivateApi:
assert mock_account.interface_language == language assert mock_account.interface_language == language
assert mock_account.timezone == timezone assert mock_account.timezone == timezone
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.db")
def test_activation_returns_success_response( def test_activation_returns_success_response(
@ -376,7 +390,7 @@ class TestActivateApi:
# Assert # Assert
assert response == {"result": "success"} assert response == {"result": "success"}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.db")
def test_activation_without_workspace_id( def test_activation_without_workspace_id(
@ -415,3 +429,37 @@ class TestActivateApi:
# Assert # Assert
assert response["result"] == "success" assert response["result"] == "success"
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token") mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
def test_activation_normalizes_email_before_lookup(
self,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
):
"""Ensure uppercase emails are normalized before lookup and revocation."""
mock_get_invitation.return_value = mock_invitation
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "Invitee@Example.com",
"token": "valid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
response = api.post()
assert response["result"] == "success"
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")

View File

@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invalid_email_with_registration_allowed( def test_login_invalid_email_with_registration_allowed(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
): ):
@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_wrong_password_returns_error( def test_login_wrong_password_returns_error(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
): ):
@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invalid_email_with_registration_disabled( def test_login_invalid_email_with_registration_disabled(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
): ):

View File

@ -0,0 +1,177 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
EmailRegisterResetApi,
EmailRegisterSendEmailApi,
)
from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
class TestEmailRegisterSendEmailApi:
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_send_email_normalizes_and_falls_back(
self,
mock_extract_ip,
mock_is_email_send_ip_limit,
mock_is_freeze,
mock_send_mail,
mock_get_account,
mock_session_cls,
app,
):
mock_send_mail.return_value = "token-123"
mock_is_freeze.return_value = False
mock_account = MagicMock()
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = mock_account
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register/send-email",
method="POST",
json={"email": "Invitee@Example.com", "language": "en-US"},
):
response = EmailRegisterSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_is_freeze.assert_called_once_with("invitee@example.com")
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
mock_extract_ip.assert_called_once()
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
class TestEmailRegisterCheckApi:
@patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
def test_validity_normalizes_email_before_checks(
self,
mock_rate_limit_check,
mock_get_data,
mock_add_rate,
mock_revoke,
mock_generate_token,
mock_reset_rate,
app,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
mock_generate_token.return_value = (None, "new-token")
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register/validity",
method="POST",
json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
):
response = EmailRegisterCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_rate_limit_check.assert_called_once_with("user@example.com")
mock_generate_token.assert_called_once_with(
"user@example.com", code="4321", additional_data={"phase": "register"}
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_add_rate.assert_not_called()
mock_revoke.assert_called_once_with("token-123")
class TestEmailRegisterResetApi:
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_reset_creates_account_with_normalized_email(
self,
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_create_account,
mock_login,
mock_reset_login_rate,
app,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register",
method="POST",
json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
):
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert account is expected_account
assert mock_session.execute.call_count == 2

View File

@ -0,0 +1,176 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
ForgotPasswordResetApi,
ForgotPasswordSendEmailApi,
)
from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
class TestForgotPasswordSendEmailApi:
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
def test_send_normalizes_email(
self,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_account,
mock_session_cls,
app,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
controller_features = SimpleNamespace(is_allow_register=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch(
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
return_value=controller_features,
),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
with app.test_request_context(
"/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=mock_account,
email="user@example.com",
language="zh-Hans",
is_allow_register=True,
)
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_extract_ip.assert_called_once()
class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_check_normalizes_email(
self,
mock_rate_limit_check,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
mock_generate_token.return_value = (None, "new-token")
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
mock_rate_limit_check.assert_called_once_with("admin@example.com")
mock_generate_token.assert_called_once_with(
"Admin@Example.com",
code="4321",
additional_data={"phase": "reset"},
)
mock_reset_rate.assert_called_once_with("admin@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_fetches_account_with_original_email(
self,
mock_get_reset_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_update_account,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("token-123")
mock_revoke_token.assert_called_once_with("token-123")
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
assert account is expected_account
assert mock_session.execute.call_count == 2

View File

@ -76,7 +76,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login") @patch("controllers.console.auth.login.AccountService.login")
@ -120,7 +120,7 @@ class TestLoginApi:
response = login_api.post() response = login_api.post()
# Assert # Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!") mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
mock_login.assert_called_once() mock_login.assert_called_once()
mock_reset_rate_limit.assert_called_once_with("test@example.com") mock_reset_rate_limit.assert_called_once_with("test@example.com")
assert response.json["result"] == "success" assert response.json["result"] == "success"
@ -128,7 +128,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login") @patch("controllers.console.auth.login.AccountService.login")
@ -182,7 +182,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
""" """
Test login rejection when rate limit is exceeded. Test login rejection when rate limit is exceeded.
@ -230,7 +230,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
def test_login_fails_with_invalid_credentials( def test_login_fails_with_invalid_credentials(
@ -269,7 +269,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
def test_login_fails_for_banned_account( def test_login_fails_for_banned_account(
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
@ -298,7 +298,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.FeatureService.get_system_features")
@ -343,7 +343,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
""" """
Test login failure when invitation email doesn't match login email. Test login failure when invitation email doesn't match login email.
@ -371,6 +371,52 @@ class TestLoginApi:
with pytest.raises(InvalidEmailError): with pytest.raises(InvalidEmailError):
login_api.post() login_api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_login_retries_with_lowercase_email(
self,
mock_reset_rate_limit,
mock_login_service,
mock_get_tenants,
mock_add_rate_limit,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
mock_account,
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
mock_get_tenants.return_value = [MagicMock()]
mock_login_service.return_value = mock_token_pair
with app.test_request_context(
"/login",
method="POST",
json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")},
):
response = LoginApi().post()
assert response.json["result"] == "success"
assert mock_authenticate.call_args_list == [
(("Upper@Example.com", "ValidPass123!", None), {}),
(("upper@example.com", "ValidPass123!", None), {}),
]
mock_add_rate_limit.assert_not_called()
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
class TestLogoutApi: class TestLogoutApi:
"""Test cases for the LogoutApi endpoint.""" """Test cases for the LogoutApi endpoint."""

View File

@ -12,6 +12,7 @@ from controllers.console.auth.oauth import (
) )
from libs.oauth import OAuthUserInfo from libs.oauth import OAuthUserInfo
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import AccountService
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
@ -215,6 +216,34 @@ class TestOAuthCallback:
assert status_code == 400 assert status_code == 400
assert response["error"] == expected_error assert response["error"] == expected_error
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.redirect")
def test_invitation_comparison_is_case_insensitive(
self,
mock_redirect,
mock_register_service,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
id="123", name="Test User", email="User@Example.com"
)
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_register_service.is_valid_invite_token.return_value = True
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
resource.get("github")
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("account_status", "expected_redirect"), ("account_status", "expected_redirect"),
[ [
@ -395,12 +424,12 @@ class TestAccountGeneration:
account.name = "Test User" account.name = "Test User"
return account return account
@patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.Session") @patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.select") @patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.db")
def test_should_get_account_by_openid_or_email( def test_should_get_account_by_openid_or_email(
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
): ):
# Mock db.engine for Session creation # Mock db.engine for Session creation
mock_db.engine = MagicMock() mock_db.engine = MagicMock()
@ -410,15 +439,31 @@ class TestAccountGeneration:
result = _get_account_by_openid_or_email("github", user_info) result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123") mock_account_model.get_by_openid.assert_called_once_with("github", "123")
mock_get_account.assert_not_called()
# Test fallback to email # Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock() mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info) result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account assert result == mock_account
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert result == expected_account
assert mock_session.execute.call_count == 2
@pytest.mark.parametrize( @pytest.mark.parametrize(
("allow_register", "existing_account", "should_create"), ("allow_register", "existing_account", "should_create"),
@ -466,6 +511,35 @@ class TestAccountGeneration:
mock_register_service.register.assert_called_once_with( mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github" email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
) )
else:
mock_register_service.register.assert_not_called()
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "en-US"}):
_generate_account("github", user_info)
mock_register_service.register.assert_called_once_with(
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")

View File

@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import (
from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.error import AccountNotFound, EmailSendIpLimitError
@pytest.fixture(autouse=True)
def _mock_forgot_password_session():
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.__exit__.return_value = None
yield mock_session
@pytest.fixture(autouse=True)
def _mock_forgot_password_db():
with patch("controllers.console.auth.forgot_password.db") as mock_db:
mock_db.engine = MagicMock()
yield mock_db
class TestForgotPasswordSendEmailApi: class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails.""" """Test cases for sending password reset emails."""
@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi:
return account return account
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features") @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
def test_send_reset_email_success( def test_send_reset_email_success(
self, self,
mock_get_features, mock_get_features,
mock_send_email, mock_send_email,
mock_select, mock_get_account,
mock_session,
mock_is_ip_limit, mock_is_ip_limit,
mock_forgot_db,
mock_wraps_db, mock_wraps_db,
app, app,
mock_account, mock_account,
@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi:
""" """
# Arrange # Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_is_ip_limit.return_value = False mock_is_ip_limit.return_value = False
mock_session_instance = MagicMock() mock_get_account.return_value = mock_account
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_send_email.return_value = "reset_token_123" mock_send_email.return_value = "reset_token_123"
mock_get_features.return_value.is_allow_register = True mock_get_features.return_value.is_allow_register = True
@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi:
], ],
) )
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features") @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
def test_send_reset_email_language_handling( def test_send_reset_email_language_handling(
self, self,
mock_get_features, mock_get_features,
mock_send_email, mock_send_email,
mock_select, mock_get_account,
mock_session,
mock_is_ip_limit, mock_is_ip_limit,
mock_forgot_db,
mock_wraps_db, mock_wraps_db,
app, app,
mock_account, mock_account,
@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi:
""" """
# Arrange # Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_is_ip_limit.return_value = False mock_is_ip_limit.return_value = False
mock_session_instance = MagicMock() mock_get_account.return_value = mock_account
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_send_email.return_value = "token" mock_send_email.return_value = "token"
mock_get_features.return_value.is_allow_register = True mock_get_features.return_value.is_allow_register = True
@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi:
assert response["email"] == "test@example.com" assert response["email"] == "test@example.com"
assert response["token"] == "new_token" assert response["token"] == "new_token"
mock_revoke_token.assert_called_once_with("old_token") mock_revoke_token.assert_called_once_with("old_token")
mock_generate_token.assert_called_once_with(
"test@example.com", code="123456", additional_data={"phase": "reset"}
)
mock_reset_rate_limit.assert_called_once_with("test@example.com") mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
def test_verify_code_preserves_token_email_case(
self,
mock_reset_rate_limit,
mock_generate_token,
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
mock_generate_token.return_value = (None, "fresh-token")
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "user@example.com", "code": "999888", "token": "upper_token"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"}
mock_generate_token.assert_called_once_with(
"User@Example.com", code="999888", additional_data={"phase": "reset"}
)
mock_revoke_token.assert_called_once_with("upper_token")
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
@ -355,20 +395,16 @@ class TestForgotPasswordResetApi:
return account return account
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
def test_reset_password_success( def test_reset_password_success(
self, self,
mock_get_tenants, mock_get_tenants,
mock_select, mock_get_account,
mock_session,
mock_revoke_token, mock_revoke_token,
mock_get_data, mock_get_data,
mock_forgot_db,
mock_wraps_db, mock_wraps_db,
app, app,
mock_account, mock_account,
@ -383,11 +419,8 @@ class TestForgotPasswordResetApi:
""" """
# Arrange # Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_session_instance = MagicMock() mock_get_account.return_value = mock_account
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_tenants.return_value = [MagicMock()] mock_get_tenants.return_value = [MagicMock()]
# Act # Act
@ -475,13 +508,11 @@ class TestForgotPasswordResetApi:
api.post() api.post()
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.select")
def test_reset_password_account_not_found( def test_reset_password_account_not_found(
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
): ):
""" """
Test password reset for non-existent account. Test password reset for non-existent account.
@ -491,11 +522,8 @@ class TestForgotPasswordResetApi:
""" """
# Arrange # Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_session_instance = MagicMock() mock_get_account.return_value = None
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
mock_session.return_value.__enter__.return_value = mock_session_instance
# Act & Assert # Act & Assert
with app.test_request_context( with app.test_request_context(

View File

@ -0,0 +1,39 @@
from types import SimpleNamespace
from unittest.mock import patch
from controllers.console.setup import SetupApi
class TestSetupApi:
def test_post_lowercases_email_before_register(self):
"""Ensure setup registration normalizes email casing."""
payload = {
"email": "Admin@Example.com",
"name": "Admin User",
"password": "ValidPass123!",
"language": "en-US",
}
setup_api = SetupApi(api=None)
mock_console_ns = SimpleNamespace(payload=payload)
with (
patch("controllers.console.setup.console_ns", mock_console_ns),
patch("controllers.console.setup.get_setup_status", return_value=False),
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
patch("controllers.console.setup.get_init_validate_status", return_value=True),
patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"),
patch("controllers.console.setup.request", object()),
patch("controllers.console.setup.RegisterService.setup") as mock_register,
):
response, status = setup_api.post()
assert response == {"result": "success"}
assert status == 201
mock_register.assert_called_once_with(
email="admin@example.com",
name=payload["name"],
password=payload["password"],
ip_address="127.0.0.1",
language=payload["language"],
)

View File

@ -0,0 +1,247 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.workspace.account import (
AccountDeleteUpdateFeedbackApi,
ChangeEmailCheckApi,
ChangeEmailResetApi,
ChangeEmailSendEmailApi,
CheckEmailUnique,
)
from models import Account
from services.account_service import AccountService
@pytest.fixture
def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(_load_user=lambda: None)
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
account.email = email
account.id = account_id
account.status = "active"
account._current_tenant = tenant_obj
return account
def _set_logged_in_user(account: Account):
g._login_user = account
g._current_tenant = account.current_tenant
class TestChangeEmailSend:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_phase(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {"email": "current@example.com"}
mock_send_email.return_value = "token-abc"
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
assert response == {"result": "success", "data": "token-abc"}
mock_send_email.assert_called_once_with(
account=None,
email="new@example.com",
old_email="current@example.com",
language="en-US",
phase="new_email",
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_validate_with_normalized_email(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limit.assert_called_once_with("user@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_before_update(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {"old_email": "OLD@example.com"}
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "New@Example.com", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
ChangeEmailResetApi().post()
mock_is_freeze.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
json={"email": "User@Example.com", "feedback": "test"},
):
response = AccountDeleteUpdateFeedbackApi().post()
assert response == {"result": "success"}
mock_update.assert_called_once_with("User@Example.com", "test")
class TestCheckEmailUnique:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
with app.test_request_context(
"/account/change-email/check-email-unique",
method="POST",
json={"email": "Case@Test.com"},
):
response = CheckEmailUnique().post()
assert response == {"result": "success"}
mock_is_freeze.assert_called_once_with("case@test.com")
mock_check_unique.assert_called_once_with("case@test.com")
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
session = MagicMock()
first = MagicMock()
first.scalar_one_or_none.return_value = None
second = MagicMock()
expected_account = MagicMock()
second.scalar_one_or_none.return_value = expected_account
session.execute.side_effect = [first, second]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
assert result is expected_account
assert session.execute.call_count == 2

View File

@ -0,0 +1,82 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.workspace.members import MemberInviteEmailApi
from models.account import Account, TenantAccountRole
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
return SimpleNamespace(
billing=SimpleNamespace(enabled=False),
workspace_members=workspace_members,
members=placeholder_quota,
apps=placeholder_quota,
vector_space=placeholder_quota,
documents_upload_quota=placeholder_quota,
annotation_quota_limit=placeholder_quota,
)
class TestMemberInviteEmailApi:
@patch("controllers.console.workspace.members.FeatureService.get_features")
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
@patch("controllers.console.workspace.members.current_account_with_tenant")
@patch("controllers.console.wraps.db")
@patch("libs.login.check_csrf_token", return_value=None)
def test_invite_normalizes_emails(
self,
mock_csrf,
mock_db,
mock_current_account,
mock_invite_member,
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
mock_current_account.return_value = (inviter, tenant.id)
with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"):
with app.test_request_context(
"/workspaces/current/members/invite-email",
method="POST",
json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"},
):
account = Account(name="tester", email="tester@example.com")
account._current_tenant = tenant
g._login_user = account
g._current_tenant = tenant
response, status_code = MemberInviteEmailApi().post()
assert status_code == 201
assert response["invitation_results"][0]["email"] == "user@example.com"
assert mock_invite_member.call_count == 1
call_args = mock_invite_member.call_args
assert call_args.kwargs["tenant"] == tenant
assert call_args.kwargs["email"] == "User@Example.com"
assert call_args.kwargs["language"] == "en-US"
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
assert call_args.kwargs["inviter"] == inviter
mock_csrf.assert_called_once()

View File

@ -1,195 +0,0 @@
"""Unit tests for controllers.web.forgot_password endpoints."""
from __future__ import annotations
import base64
import builtins
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
# Ensure flask_restx.api finds MethodView during import.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_controller_module():
"""Import controllers.web.forgot_password using a stub package."""
import importlib
import importlib.util
import sys
from types import ModuleType
parent_module_name = "controllers.web"
module_name = f"{parent_module_name}.forgot_password"
if parent_module_name not in sys.modules:
from flask_restx import Namespace
stub = ModuleType(parent_module_name)
stub.__file__ = "controllers/web/__init__.py"
stub.__path__ = ["controllers/web"]
stub.__package__ = "controllers"
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
stub.web_ns = Namespace("web", description="Web API", path="/")
sys.modules[parent_module_name] = stub
return importlib.import_module(module_name)
forgot_password_module = _load_controller_module()
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
@pytest.fixture
def app() -> Flask:
"""Configure a minimal Flask app for request contexts."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture(autouse=True)
def _enable_web_endpoint_guards():
"""Stub enterprise and feature toggles used by route decorators."""
features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
):
yield
@pytest.fixture(autouse=True)
def _mock_controller_db():
"""Replace controller-level db reference with a simple stub."""
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
fake_wraps_db = SimpleNamespace(
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
)
with (
patch("controllers.web.forgot_password.db", fake_db),
patch("controllers.console.wraps.db", fake_wraps_db),
):
yield fake_db
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
def test_send_reset_email_success(
mock_extract_ip: MagicMock,
mock_is_ip_limit: MagicMock,
mock_session: MagicMock,
mock_send_email: MagicMock,
app: Flask,
):
"""POST /forgot-password returns token when email exists and limits allow."""
mock_account = MagicMock()
session_ctx = MagicMock()
mock_session.return_value.__enter__.return_value = session_ctx
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
with app.test_request_context(
"/forgot-password",
method="POST",
json={"email": "user@example.com"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "reset-token"}
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
def test_check_token_success(
mock_is_rate_limited: MagicMock,
mock_get_data: MagicMock,
mock_revoke: MagicMock,
mock_generate: MagicMock,
mock_reset_limit: MagicMock,
app: Flask,
):
"""POST /forgot-password/validity validates the code and refreshes token."""
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limited.assert_called_once_with("user@example.com")
mock_get_data.assert_called_once_with("old-token")
mock_revoke.assert_called_once_with("old-token")
mock_generate.assert_called_once_with(
"user@example.com",
code="123456",
additional_data={"phase": "reset"},
)
mock_reset_limit.assert_called_once_with("user@example.com")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_success(
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_session: MagicMock,
mock_token_bytes: MagicMock,
mock_hash_password: MagicMock,
app: Flask,
):
"""POST /forgot-password/resets updates the stored password when token is valid."""
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
account = MagicMock()
session_ctx = MagicMock()
mock_session.return_value.__enter__.return_value = session_ctx
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_data.assert_called_once_with("reset-token")
mock_revoke_token.assert_called_once_with("reset-token")
mock_token_bytes.assert_called_once_with(16)
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
expected_password = base64.b64encode(b"hashed-value").decode()
assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt
session_ctx.commit.assert_called_once()

View File

@ -0,0 +1,226 @@
import base64
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.forgot_password import (
ForgotPasswordCheckApi,
ForgotPasswordResetApi,
ForgotPasswordSendEmailApi,
)
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture(autouse=True)
def _patch_wraps():
wraps_features = SimpleNamespace(enable_email_password_login=True)
dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
with (
patch("controllers.console.wraps.db") as mock_db,
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield
class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.Session")
def test_should_normalize_email_before_sending(
self,
mock_session_cls,
mock_extract_ip,
mock_rate_limit,
mock_get_account,
mock_send_mail,
app,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
mock_extract_ip.assert_called_once()
mock_rate_limit.assert_called_once_with("127.0.0.1")
class TestForgotPasswordCheckApi:
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_should_normalize_email_for_validity_checks(
self,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
"/web/forgot-password/validity",
method="POST",
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limit.assert_called_once_with("user@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"User@Example.com",
code="1234",
additional_data={"phase": "reset"},
)
mock_reset_rate.assert_called_once_with("user@example.com")
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_should_preserve_token_email_case(
self,
mock_is_rate_limit,
mock_get_data,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
mock_generate_token.return_value = (None, "fresh-token")
with app.test_request_context(
"/web/forgot-password/validity",
method="POST",
json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"}
mock_generate_token.assert_called_once_with(
"MixedCase@Example.com",
code="5678",
additional_data={"phase": "reset"},
)
mock_revoke_token.assert_called_once_with("token-upper")
mock_reset_rate.assert_called_once_with("mixedcase@example.com")
class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback(
self,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_get_account,
mock_update_account,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
mock_revoke_token.assert_called_once_with("token-123")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_should_update_password_and_commit(
self,
mock_get_account,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_token_bytes,
mock_hash_password,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock()
mock_get_account.return_value = account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("reset-token")
mock_revoke_token.assert_called_once_with("reset-token")
mock_token_bytes.assert_called_once_with(16)
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
expected_password = base64.b64encode(b"hashed-value").decode()
assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt
mock_session.commit.assert_called_once()

View File

@ -0,0 +1,91 @@
import base64
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
def encode_code(code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode()
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture(autouse=True)
def _patch_wraps():
wraps_features = SimpleNamespace(enable_email_password_login=True)
console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True)
with (
patch("controllers.console.wraps.db") as mock_db,
patch("controllers.console.wraps.dify_config", console_dify),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield
class TestEmailCodeLoginSendEmailApi:
@patch("controllers.web.login.WebAppAuthService.send_email_code_login_email")
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
def test_should_fetch_account_with_original_email(
self,
mock_get_user,
mock_send_email,
app,
):
mock_account = MagicMock()
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token-123"
with app.test_request_context(
"/web/email-code-login",
method="POST",
json={"email": "User@Example.com", "language": "en-US"},
):
response = EmailCodeLoginSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_user.assert_called_once_with("User@Example.com")
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
class TestEmailCodeLoginApi:
@patch("controllers.web.login.AccountService.reset_login_error_rate_limit")
@patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token")
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data")
def test_should_normalize_email_before_validating(
self,
mock_get_token_data,
mock_revoke_token,
mock_get_user,
mock_login,
mock_reset_login_rate,
app,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_user.return_value = MagicMock()
with app.test_request_context(
"/web/email-code-login/validity",
method="POST",
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
):
response = EmailCodeLoginApi().post()
assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}}
mock_get_user.assert_called_once_with("User@Example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_login.assert_called_once()
mock_reset_login_rate.assert_called_once_with("user@example.com")

View File

@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M
def scalar(self, _stmt): def scalar(self, _stmt):
return self.results.pop(0) return self.results.pop(0)
# SQLAlchemy Session APIs used by code under test
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
# support `with session_factory.create_session() as session:`
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
tenant = SimpleNamespace(id="tenant_id") tenant = SimpleNamespace(id="tenant_id")
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) # Monkeypatch session factory to return our stub session
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([tenant, None, end_user]),
)
entity = ToolEntity( entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt
def scalar(self, _stmt): def scalar(self, _stmt):
return self.results.pop(0) return self.results.pop(0)
db_stub = SimpleNamespace(session=StubSession([None])) def expunge(self, *_args, **_kwargs):
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) pass
def close(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
# Monkeypatch session factory to return our stub session with no tenant
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([None]),
)
entity = ToolEntity( entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),

View File

@ -35,7 +35,6 @@ from core.variables.variables import (
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, Variable,
VariableUnion,
) )
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
@ -96,7 +95,7 @@ class _Segments(BaseModel):
class _Variables(BaseModel): class _Variables(BaseModel):
variables: list[VariableUnion] variables: list[Variable]
def create_test_file( def create_test_file(
@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
# Create one instance of each variable type # Create one instance of each variable type
test_file = create_test_file() test_file = create_test_file()
all_variables: list[VariableUnion] = [ all_variables: list[Variable] = [
NoneVariable(name="none_var"), NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"), StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"), IntegerVariable(value=42, name="int_var"),

View File

@ -11,7 +11,7 @@ from core.variables import (
SegmentType, SegmentType,
StringVariable, StringVariable,
) )
from core.variables.variables import Variable from core.variables.variables import VariableBase
def test_frozen_variables(): def test_frozen_variables():
@ -76,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object(): def test_variable_to_object():
var: Variable = StringVariable(name="text", value="text") var: VariableBase = StringVariable(name="text", value="text")
assert var.to_object() == "text" assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42) var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42 assert var.to_object() == 42

View File

@ -24,7 +24,7 @@ from core.variables.variables import (
IntegerVariable, IntegerVariable,
ObjectVariable, ObjectVariable,
StringVariable, StringVariable,
VariableUnion, Variable,
) )
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
) )
# Create environment variables with all types including ArrayFileVariable # Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [ env_vars: list[Variable] = [
StringVariable( StringVariable(
id="env_string_id", id="env_string_id",
name="env_string", name="env_string",
@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
] ]
# Create conversation variables with complex data # Create conversation variables with complex data
conv_vars: list[VariableUnion] = [ conv_vars: list[Variable] = [
StringVariable( StringVariable(
id="conv_string_id", id="conv_string_id",
name="conv_string", name="conv_string",

View File

@ -2,13 +2,17 @@ from types import SimpleNamespace
import pytest import pytest
from configs import dify_config
from core.file.enums import FileType from core.file.enums import FileType
from core.file.models import File, FileTransferMethod from core.file.models import File, FileTransferMethod
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.variables import StringVariable from core.variables.variables import StringVariable
from core.workflow.constants import ( from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID,
) )
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
@ -96,6 +100,58 @@ class TestWorkflowEntry:
assert output_var is not None assert output_var is not None
assert output_var.value == "system_user" assert output_var.value == "system_user"
def test_single_step_run_injects_code_limits(self):
"""Ensure single-step CodeNode execution configures limits."""
# Arrange
node_id = "code_node"
node_data = {
"type": "code",
"title": "Code",
"desc": None,
"variables": [],
"code_language": CodeLanguage.PYTHON3,
"code": "def main():\n return {}",
"outputs": {},
}
node_config = {"id": node_id, "data": node_data}
class StubWorkflow:
def __init__(self):
self.tenant_id = "tenant"
self.app_id = "app"
self.id = "workflow"
self.graph_dict = {"nodes": [node_config], "edges": []}
def get_node_config_by_id(self, target_id: str):
assert target_id == node_id
return node_config
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
# Act
node, _ = WorkflowEntry.single_step_run(
workflow=workflow,
node_id=node_id,
user_id="user",
user_inputs={},
variable_pool=variable_pool,
)
# Assert
assert isinstance(node, CodeNode)
assert node._limits == expected_limits
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self): def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
"""Test mapping environment variables from user inputs to variable pool.""" """Test mapping environment variables from user inputs to variable pool."""
# Initialize variable pool with environment variables # Initialize variable pool with environment variables

View File

@ -0,0 +1 @@
"""LogStore extension unit tests."""

View File

@ -0,0 +1,469 @@
"""
Unit tests for SQL escape utility functions.
These tests ensure that SQL injection attacks are properly prevented
in LogStore queries, particularly for cross-tenant access scenarios.
"""
import pytest
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
class TestEscapeSQLString:
"""Test escape_sql_string function."""
def test_escape_empty_string(self):
"""Test escaping empty string."""
assert escape_sql_string("") == ""
def test_escape_normal_string(self):
"""Test escaping string without special characters."""
assert escape_sql_string("tenant_abc123") == "tenant_abc123"
assert escape_sql_string("app-uuid-1234") == "app-uuid-1234"
def test_escape_single_quote(self):
"""Test escaping single quote."""
# Single quote should be doubled
assert escape_sql_string("tenant'id") == "tenant''id"
assert escape_sql_string("O'Reilly") == "O''Reilly"
def test_escape_multiple_quotes(self):
"""Test escaping multiple single quotes."""
assert escape_sql_string("a'b'c") == "a''b''c"
assert escape_sql_string("'''") == "''''''"
# === SQL Injection Attack Scenarios ===
def test_prevent_boolean_injection(self):
"""Test prevention of boolean injection attacks."""
# Classic OR 1=1 attack
malicious_input = "tenant' OR '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR ''1''=''1"
# When used in SQL, this becomes a safe string literal
sql = f"WHERE tenant_id='{escaped}'"
assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'"
# The entire input is now a string literal that won't match any tenant
def test_prevent_or_injection(self):
"""Test prevention of OR-based injection."""
malicious_input = "tenant_a' OR tenant_id='tenant_b"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant_a'' OR tenant_id=''tenant_b"
sql = f"WHERE tenant_id='{escaped}'"
# The OR is now part of the string literal, not SQL logic
assert "OR tenant_id=" in sql
# The SQL has: opening ', doubled internal quotes '', and closing '
assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'"
def test_prevent_union_injection(self):
"""Test prevention of UNION-based injection."""
malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1"
# UNION becomes part of the string literal
assert "UNION" in escaped
assert escaped.count("''") == 4 # All internal quotes are doubled
def test_prevent_comment_injection(self):
"""Test prevention of comment-based injection."""
# SQL comment to bypass remaining conditions
malicious_input = "tenant' --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' --"
sql = f"WHERE tenant_id='{escaped}' AND deleted=false"
# The -- is now inside the string, not a SQL comment
assert "--" in sql
assert "AND deleted=false" in sql # This part is NOT commented out
def test_prevent_semicolon_injection(self):
"""Test prevention of semicolon-based multi-statement injection."""
malicious_input = "tenant'; DROP TABLE users; --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant''; DROP TABLE users; --"
# Semicolons and DROP are now part of the string
assert "DROP TABLE" in escaped
def test_prevent_time_based_blind_injection(self):
"""Test prevention of time-based blind SQL injection."""
malicious_input = "tenant' AND SLEEP(5) --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' AND SLEEP(5) --"
# SLEEP becomes part of the string
assert "SLEEP" in escaped
def test_prevent_wildcard_injection(self):
"""Test prevention of wildcard-based injection."""
malicious_input = "tenant' OR tenant_id LIKE '%"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR tenant_id LIKE ''%"
# The LIKE and wildcard are now part of the string
assert "LIKE" in escaped
def test_prevent_null_byte_injection(self):
"""Test handling of null bytes."""
# Null bytes can sometimes bypass filters
malicious_input = "tenant\x00' OR '1'='1"
escaped = escape_sql_string(malicious_input)
# Null byte is preserved, but quote is escaped
assert "''1''=''1" in escaped
# === Real-world SAAS Scenarios ===
def test_cross_tenant_access_attempt(self):
"""Test prevention of cross-tenant data access."""
# Attacker tries to access another tenant's data
attacker_input = "tenant_b' OR tenant_id='tenant_a"
escaped = escape_sql_string(attacker_input)
sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'"
# The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a"
# which doesn't exist - preventing access to either tenant's data
assert "tenant_b'' OR tenant_id=''tenant_a" in sql
def test_cross_app_access_attempt(self):
"""Test prevention of cross-application data access."""
attacker_input = "app1' OR app_id='app2"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE app_id='{escaped}'"
# Cannot access app2's data
assert "app1'' OR app_id=''app2" in sql
def test_bypass_status_filter(self):
"""Test prevention of bypassing status filters."""
# Try to see all statuses instead of just 'running'
attacker_input = "running' OR status LIKE '%"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE status='{escaped}'"
# Status condition is not bypassed
assert "running'' OR status LIKE ''%" in sql
# === Edge Cases ===
def test_escape_only_quotes(self):
"""Test string with only quotes."""
assert escape_sql_string("'") == "''"
assert escape_sql_string("''") == "''''"
def test_escape_mixed_content(self):
"""Test string with mixed quotes and other chars."""
input_str = "It's a 'test' of O'Reilly's code"
escaped = escape_sql_string(input_str)
assert escaped == "It''s a ''test'' of O''Reilly''s code"
def test_escape_unicode_with_quotes(self):
"""Test Unicode strings with quotes."""
input_str = "租户' OR '1'='1"
escaped = escape_sql_string(input_str)
assert escaped == "租户'' OR ''1''=''1"
class TestEscapeIdentifier:
"""Test escape_identifier function."""
def test_escape_uuid(self):
"""Test escaping UUID identifiers."""
uuid = "550e8400-e29b-41d4-a716-446655440000"
assert escape_identifier(uuid) == uuid
def test_escape_alphanumeric_id(self):
"""Test escaping alphanumeric identifiers."""
assert escape_identifier("tenant_123") == "tenant_123"
assert escape_identifier("app-abc-123") == "app-abc-123"
def test_escape_identifier_with_quote(self):
"""Test escaping identifier with single quote."""
malicious = "tenant' OR '1'='1"
escaped = escape_identifier(malicious)
assert escaped == "tenant'' OR ''1''=''1"
def test_identifier_injection_attempt(self):
"""Test prevention of injection through identifiers."""
# Common identifier injection patterns
test_cases = [
("id' OR '1'='1", "id'' OR ''1''=''1"),
("id'; DROP TABLE", "id''; DROP TABLE"),
("id' UNION SELECT", "id'' UNION SELECT"),
]
for malicious, expected in test_cases:
assert escape_identifier(malicious) == expected
class TestSQLInjectionIntegration:
"""Integration tests simulating real SQL construction scenarios."""
def test_complete_where_clause_safety(self):
"""Test that a complete WHERE clause is safe from injection."""
# Simulating typical query construction
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
run_id = "run' --"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
escaped_run = escape_identifier(run_id)
sql = f"""
SELECT * FROM workflow_runs
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND id='{escaped_run}'
"""
# Verify all special characters are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
assert "run'' --" in sql
# Verify SQL structure is preserved (3 conditions with AND)
assert sql.count("AND") == 2
def test_multiple_conditions_with_injection_attempts(self):
"""Test multiple conditions all attempting injection."""
conditions = {
"tenant_id": "t1' OR tenant_id='t2",
"app_id": "a1' OR app_id='a2",
"status": "running' OR '1'='1",
}
where_parts = []
for field, value in conditions.items():
escaped = escape_sql_string(value)
where_parts.append(f"{field}='{escaped}'")
where_clause = " AND ".join(where_parts)
# All injection attempts are neutralized
assert "t1'' OR tenant_id=''t2" in where_clause
assert "a1'' OR app_id=''a2" in where_clause
assert "running'' OR ''1''=''1" in where_clause
# AND structure is preserved
assert where_clause.count(" AND ") == 2
@pytest.mark.parametrize(
("attack_vector", "description"),
[
("' OR '1'='1", "Boolean injection"),
("' OR '1'='1' --", "Boolean with comment"),
("' UNION SELECT * FROM users --", "Union injection"),
("'; DROP TABLE workflow_runs; --", "Destructive command"),
("' AND SLEEP(10) --", "Time-based blind"),
("' OR tenant_id LIKE '%", "Wildcard injection"),
("admin' --", "Comment bypass"),
("' OR 1=1 LIMIT 1 --", "Limit bypass"),
],
)
def test_common_injection_vectors(self, attack_vector, description):
"""Test protection against common injection attack vectors."""
escaped = escape_sql_string(attack_vector)
# Build SQL
sql = f"WHERE tenant_id='{escaped}'"
# Verify the attack string is now a safe literal
# The key indicator: all internal single quotes are doubled
internal_quotes = escaped.count("''")
original_quotes = attack_vector.count("'")
# Each original quote should be doubled
assert internal_quotes == original_quotes
# Verify SQL has exactly 2 quotes (opening and closing)
assert sql.count("'") >= 2 # At least opening and closing
def test_logstore_specific_scenario(self):
"""Test SQL injection prevention in LogStore-specific scenarios."""
# Simulate LogStore query with window function
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM workflow_execution_logstore
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND __time__ > 0
) AS subquery WHERE rn = 1
"""
# Complex query structure is maintained
assert "ROW_NUMBER()" in sql
assert "PARTITION BY id" in sql
# Injection attempts are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
# ====================================================================================
# Tests for LogStore Query Syntax (SDK Mode)
# ====================================================================================
class TestLogStoreQueryEscape:
"""Test escape_logstore_query_value for SDK mode query syntax."""
def test_normal_value(self):
"""Test escaping normal alphanumeric value."""
value = "550e8400-e29b-41d4-a716-446655440000"
escaped = escape_logstore_query_value(value)
# Should be wrapped in double quotes
assert escaped == '"550e8400-e29b-41d4-a716-446655440000"'
def test_empty_value(self):
"""Test escaping empty string."""
assert escape_logstore_query_value("") == '""'
def test_value_with_and_keyword(self):
"""Test that 'and' keyword is neutralized when quoted."""
malicious = "value and field:evil"
escaped = escape_logstore_query_value(malicious)
# Should be wrapped in quotes, making 'and' a literal
assert escaped == '"value and field:evil"'
# Simulate using in query
query = f"tenant_id:{escaped}"
assert query == 'tenant_id:"value and field:evil"'
def test_value_with_or_keyword(self):
"""Test that 'or' keyword is neutralized when quoted."""
malicious = "tenant_a or tenant_id:tenant_b"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"tenant_a or tenant_id:tenant_b"'
query = f"tenant_id:{escaped}"
assert "or" in query # Present but as literal string
def test_value_with_not_keyword(self):
"""Test that 'not' keyword is neutralized when quoted."""
malicious = "not field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"not field:value"'
def test_value_with_parentheses(self):
"""Test that parentheses are neutralized when quoted."""
malicious = "(tenant_a or tenant_b)"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"(tenant_a or tenant_b)"'
assert "(" in escaped # Present as literal
assert ")" in escaped # Present as literal
def test_value_with_colon(self):
"""Test that colons are neutralized when quoted."""
malicious = "field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"field:value"'
assert ":" in escaped # Present as literal
def test_value_with_double_quotes(self):
"""Test that internal double quotes are escaped."""
value_with_quotes = 'tenant"test"value'
escaped = escape_logstore_query_value(value_with_quotes)
# Double quotes should be escaped with backslash
assert escaped == '"tenant\\"test\\"value"'
# Should have outer quotes plus escaped inner quotes
assert '\\"' in escaped
def test_value_with_backslash(self):
"""Test that backslashes are escaped."""
value_with_backslash = "tenant\\test"
escaped = escape_logstore_query_value(value_with_backslash)
# Backslash should be escaped
assert escaped == '"tenant\\\\test"'
assert "\\\\" in escaped
def test_value_with_backslash_and_quote(self):
"""Test escaping both backslash and double quote."""
value = 'path\\to\\"file"'
escaped = escape_logstore_query_value(value)
# Both should be escaped
assert escaped == '"path\\\\to\\\\\\"file\\""'
# Verify escape order is correct
assert "\\\\" in escaped # Escaped backslash
assert '\\"' in escaped # Escaped double quote
def test_complex_injection_attempt(self):
"""Test complex injection combining multiple operators."""
malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")'
escaped = escape_logstore_query_value(malicious)
# All special chars should be literals or escaped
assert escaped.startswith('"')
assert escaped.endswith('"')
# Inner double quotes escaped, operators become literals
assert "or" in escaped
assert "and" in escaped
assert '\\"' in escaped # Escaped quotes
def test_only_backslash(self):
"""Test escaping a single backslash."""
assert escape_logstore_query_value("\\") == '"\\\\"'
def test_only_double_quote(self):
"""Test escaping a single double quote."""
assert escape_logstore_query_value('"') == '"\\""'
def test_multiple_backslashes(self):
"""Test escaping multiple consecutive backslashes."""
assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6
def test_escape_sequence_like_input(self):
"""Test that existing escape sequences are properly escaped."""
# Input looks like already escaped, but we still escape it
value = 'value\\"test'
escaped = escape_logstore_query_value(value)
# \\ -> \\\\, " -> \"
assert escaped == '"value\\\\\\"test"'
@pytest.mark.parametrize(
("attack_scenario", "field", "malicious_value"),
[
("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"),
("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"),
("Boolean logic", "status", "succeeded or status:failed"),
("Negation", "tenant_id", "not tenant_a"),
("Field injection", "run_id", "run123 and tenant_id:evil_tenant"),
("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"),
("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'),
("Backslash escape bypass", "app_id", "app\\ and app_id:evil"),
],
)
def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str):
"""Test that various LogStore query injection attempts are neutralized."""
escaped = escape_logstore_query_value(malicious_value)
# Build query
query = f"{field}:{escaped}"
# All operators should be within quoted string (literals)
assert escaped.startswith('"')
assert escaped.endswith('"')
# Verify the full query structure is safe
assert query.count(":") >= 1 # At least the main field:value separator

View File

@ -0,0 +1,59 @@
"""Unit tests for traceparent header propagation in EnterpriseRequest.
This test module verifies that the W3C traceparent header is properly
generated and included in HTTP requests made by EnterpriseRequest.
"""
from unittest.mock import MagicMock, patch
import pytest
from services.enterprise.base import EnterpriseRequest
class TestTraceparentPropagation:
"""Unit tests for traceparent header propagation."""
@pytest.fixture
def mock_enterprise_config(self):
"""Mock EnterpriseRequest configuration."""
with (
patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"),
patch.object(EnterpriseRequest, "secret_key", "test-secret-key"),
patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"),
):
yield
@pytest.fixture
def mock_httpx_client(self):
"""Mock httpx.Client for testing."""
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__.return_value = mock_client
mock_client_class.return_value.__exit__.return_value = None
# Setup default response
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_client.request.return_value = mock_response
yield mock_client
def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client):
"""Test that traceparent header is included when successfully generated."""
# Arrange
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
# Act
EnterpriseRequest.send_request("GET", "/test")
# Assert
mock_httpx_client.request.assert_called_once()
call_args = mock_httpx_client.request.call_args
headers = call_args[1]["headers"]
assert "traceparent" in headers
assert headers["traceparent"] == expected_traceparent
assert headers["Content-Type"] == "application/json"
assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key"

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from configs import dify_config from configs import dify_config
from models.account import Account from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import ( from services.errors.account import (
AccountAlreadyInTenantError, AccountAlreadyInTenantError,
@ -1147,9 +1147,13 @@ class TestRegisterService:
mock_session = MagicMock() mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
with patch("services.account_service.Session") as mock_session_class: with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None
# Mock RegisterService.register # Mock RegisterService.register
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
@ -1182,9 +1186,59 @@ class TestRegisterService:
email="newuser@example.com", email="newuser@example.com",
name="newuser", name="newuser",
language="en-US", language="en-US",
status="pending", status=AccountStatus.PENDING,
is_setup=True, is_setup=True,
) )
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
def test_invite_new_member_normalizes_new_account_email(
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
):
"""Ensure inviting with mixed-case email normalizes before registering."""
mock_tenant = MagicMock()
mock_tenant.id = "tenant-456"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
mixed_email = "Invitee@Example.com"
mock_session = MagicMock()
with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = None
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending"
)
with patch("services.account_service.RegisterService.register") as mock_register:
mock_register.return_value = mock_new_account
with (
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
):
mock_generate_token.return_value = "invite-token-abc"
RegisterService.invite_new_member(
tenant=mock_tenant,
email=mixed_email,
language="en-US",
role="normal",
inviter=mock_inviter,
)
mock_register.assert_called_once_with(
email="invitee@example.com",
name="invitee",
language="en-US",
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account) mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
@ -1207,9 +1261,13 @@ class TestRegisterService:
mock_session = MagicMock() mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
with patch("services.account_service.Session") as mock_session_class: with (
patch("services.account_service.Session") as mock_session_class,
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_class.return_value.__exit__.return_value = None mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account
# Mock the db.session.query for TenantAccountJoin # Mock the db.session.query for TenantAccountJoin
mock_db_query = MagicMock() mock_db_query = MagicMock()
@ -1238,6 +1296,7 @@ class TestRegisterService:
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
mock_task_dependencies.delay.assert_called_once() mock_task_dependencies.delay.assert_called_once()
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
"""Test inviting a member who is already in the tenant.""" """Test inviting a member who is already in the tenant."""
@ -1251,7 +1310,6 @@ class TestRegisterService:
# Mock database queries # Mock database queries
query_results = { query_results = {
("Account", "email", "existing@example.com"): mock_existing_account,
( (
"TenantAccountJoin", "TenantAccountJoin",
"tenant_id", "tenant_id",
@ -1261,7 +1319,11 @@ class TestRegisterService:
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
# Mock TenantService methods # Mock TenantService methods
with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission: with (
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
):
mock_lookup.return_value = mock_existing_account
# Execute test and verify exception # Execute test and verify exception
self._assert_exception_raised( self._assert_exception_raised(
AccountAlreadyInTenantError, AccountAlreadyInTenantError,
@ -1272,6 +1334,7 @@ class TestRegisterService:
role="normal", role="normal",
inviter=mock_inviter, inviter=mock_inviter,
) )
mock_lookup.assert_called_once()
def test_invite_new_member_no_inviter(self): def test_invite_new_member_no_inviter(self):
"""Test inviting a member without providing an inviter.""" """Test inviting a member without providing an inviter."""
@ -1497,6 +1560,30 @@ class TestRegisterService:
# Verify results # Verify results
assert result is None assert result is None
def test_get_invitation_with_case_fallback_returns_initial_match(self):
"""Fallback helper should return the initial invitation when present."""
invitation = {"workspace_id": "tenant-456"}
with patch(
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
) as mock_get:
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
assert result == invitation
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
"""Fallback helper should retry with lowercase email when needed."""
invitation = {"workspace_id": "tenant-456"}
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
mock_get.side_effect = [None, invitation]
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
assert result == invitation
assert mock_get.call_args_list == [
(("tenant-456", "User@Test.com", "token-123"),),
(("tenant-456", "user@test.com", "token-123"),),
]
# ==================== Helper Method Tests ==================== # ==================== Helper Method Tests ====================
def test_get_invitation_token_key(self): def test_get_invitation_token_key(self):

14
api/uv.lock generated
View File

@ -453,15 +453,15 @@ wheels = [
[[package]] [[package]]
name = "azure-core" name = "azure-core"
version = "1.36.0" version = "1.38.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "requests" }, { name = "requests" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" } sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" }, { url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" },
] ]
[[package]] [[package]]
@ -1368,7 +1368,7 @@ wheels = [
[[package]] [[package]]
name = "dify-api" name = "dify-api"
version = "1.11.2" version = "1.11.3"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "aliyun-log-python-sdk" }, { name = "aliyun-log-python-sdk" },
@ -1965,11 +1965,11 @@ wheels = [
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.20.0" version = "3.20.3"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" },
] ]
[[package]] [[package]]

View File

@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Options: # Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation # Core workflow node execution repository implementation
# Options: # Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) # - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository # - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation # API workflow run repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation # API workflow node execution repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# Workflow log cleanup configuration # Workflow log cleanup configuration

View File

@ -21,7 +21,7 @@ services:
# API service # API service
api: api:
image: langgenius/dify-api:1.11.2 image: langgenius/dify-api:1.11.3
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service # worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.) # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker: worker:
image: langgenius/dify-api:1.11.2 image: langgenius/dify-api:1.11.3
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service # worker_beat service
# Celery beat for scheduling periodic tasks. # Celery beat for scheduling periodic tasks.
worker_beat: worker_beat:
image: langgenius/dify-api:1.11.2 image: langgenius/dify-api:1.11.3
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:1.11.2 image: langgenius/dify-web:1.11.3
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -704,7 +704,7 @@ services:
# API service # API service
api: api:
image: langgenius/dify-api:1.11.2 image: langgenius/dify-api:1.11.3
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -746,7 +746,7 @@ services:
# worker service # worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.) # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker: worker:
image: langgenius/dify-api:1.11.2 image: langgenius/dify-api:1.11.3
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -785,7 +785,7 @@ services:
# worker_beat service # worker_beat service
# Celery beat for scheduling periodic tasks. # Celery beat for scheduling periodic tasks.
worker_beat: worker_beat:
image: langgenius/dify-api:1.11.2 image: langgenius/dify-api:1.11.3
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -815,7 +815,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:1.11.2 image: langgenius/dify-web:1.11.3
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -31,6 +31,8 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
# The timeout for the text generation in millisecond # The timeout for the text generation in millisecond
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only)
TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST= NEXT_PUBLIC_CSP_WHITELIST=

View File

@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => {
) )
return { return {
useGlobalPublicStore, useGlobalPublicStore,
useIsSystemFeaturesPending: () => false,
} }
}) })

View File

@ -9,8 +9,8 @@ import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
} from '@/app/education-apply/constants' } from '@/app/education-apply/constants'
import { fetchSetupStatus } from '@/service/common'
import { sendGAEvent } from '@/utils/gtag' import { sendGAEvent } from '@/utils/gtag'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
import { trackEvent } from './base/amplitude' import { trackEvent } from './base/amplitude'
@ -33,15 +33,8 @@ export const AppInitializer = ({
const isSetupFinished = useCallback(async () => { const isSetupFinished = useCallback(async () => {
try { try {
if (localStorage.getItem('setup_status') === 'finished') const setUpStatus = await fetchSetupStatusWithCache()
return true return setUpStatus.step === 'finished'
const setUpStatus = await fetchSetupStatus()
if (setUpStatus.step !== 'finished') {
localStorage.removeItem('setup_status')
return false
}
localStorage.setItem('setup_status', 'finished')
return true
} }
catch (error) { catch (error) {
console.error(error) console.error(error)

View File

@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({
}), }),
})) }))
vi.mock('@/service/common', () => ({
fetchCurrentWorkspace: vi.fn(),
fetchLangGeniusVersion: vi.fn(),
fetchUserProfile: vi.fn(),
getSystemFeatures: vi.fn(),
}))
vi.mock('@/service/access-control', () => ({ vi.mock('@/service/access-control', () => ({
useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args),
useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args),
@ -125,7 +118,6 @@ const resetAccessControlStore = () => {
const resetGlobalStore = () => { const resetGlobalStore = () => {
useGlobalPublicStore.setState({ useGlobalPublicStore.setState({
systemFeatures: defaultSystemFeatures, systemFeatures: defaultSystemFeatures,
isGlobalPending: false,
}) })
} }

View File

@ -54,7 +54,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => {
} }
const AmplitudeProvider: FC<IAmplitudeProps> = ({ const AmplitudeProvider: FC<IAmplitudeProps> = ({
sessionReplaySampleRate = 1, sessionReplaySampleRate = 0.5,
}) => { }) => {
useEffect(() => { useEffect(() => {
// Only enable in Saas edition with valid API key // Only enable in Saas edition with valid API key

View File

@ -170,8 +170,12 @@ describe('useChatWithHistory', () => {
await waitFor(() => { await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1') expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1')
}) })
expect(result.current.pinnedConversationList).toEqual(pinnedData.data) await waitFor(() => {
expect(result.current.conversationList).toEqual(listData.data) expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
})
await waitFor(() => {
expect(result.current.conversationList).toEqual(listData.data)
})
}) })
}) })

View File

@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react' import * as React from 'react'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast' import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config' import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type' import { Plan } from '../../../type'
@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({
})) }))
vi.mock('@/service/billing', () => ({ vi.mock('@/service/billing', () => ({
fetchBillingUrl: vi.fn(),
fetchSubscriptionUrls: vi.fn(), fetchSubscriptionUrls: vi.fn(),
})) }))
vi.mock('@/service/client', () => ({
consoleClient: {
billingUrl: vi.fn(),
},
}))
vi.mock('@/hooks/use-async-window-open', () => ({ vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: vi.fn(), useAsyncWindowOpen: vi.fn(),
})) }))
@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({
const mockUseAppContext = useAppContext as Mock const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockFetchBillingUrl = fetchBillingUrl as Mock const mockBillingUrl = consoleClient.billingUrl as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock const mockToastNotify = Toast.notify as Mock
@ -69,7 +75,7 @@ beforeEach(() => {
vi.clearAllMocks() vi.clearAllMocks()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true }) mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open())) mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' }) mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' }) mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
assignedHref = '' assignedHref = ''
}) })
@ -143,7 +149,7 @@ describe('CloudPlanItem', () => {
type: 'error', type: 'error',
message: 'billing.buyPermissionDeniedTip', message: 'billing.buyPermissionDeniedTip',
})) }))
expect(mockFetchBillingUrl).not.toHaveBeenCalled() expect(mockBillingUrl).not.toHaveBeenCalled()
}) })
it('should open billing portal when upgrading current paid plan', async () => { it('should open billing portal when upgrading current paid plan', async () => {
@ -162,7 +168,7 @@ describe('CloudPlanItem', () => {
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' })) fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
await waitFor(() => { await waitFor(() => {
expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1) expect(mockBillingUrl).toHaveBeenCalledTimes(1)
}) })
expect(openWindow).toHaveBeenCalledTimes(1) expect(openWindow).toHaveBeenCalledTimes(1)
}) })

View File

@ -6,7 +6,8 @@ import { useMemo } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast' import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config' import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type' import { Plan } from '../../../type'
@ -76,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
try { try {
if (isCurrentPaidPlan) { if (isCurrentPaidPlan) {
await openAsyncWindow(async () => { await openAsyncWindow(async () => {
const res = await fetchBillingUrl() const res = await consoleClient.billingUrl()
if (res.url) if (res.url)
return res.url return res.url
throw new Error('Failed to open billing page') throw new Error('Failed to open billing page')

View File

@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
category: PluginCategoryEnum.datasource, category: PluginCategoryEnum.datasource,
exclude, exclude,
type: 'plugin', type: 'plugin',
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
}) })
} }
else { else {
@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
query: '', query: '',
category: PluginCategoryEnum.datasource, category: PluginCategoryEnum.datasource,
type: 'plugin', type: 'plugin',
pageSize: 1000, page_size: 1000,
exclude, exclude,
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
}) })
} }
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude]) }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
category: PluginCategoryEnum.model, category: PluginCategoryEnum.model,
exclude, exclude,
type: 'plugin', type: 'plugin',
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
}) })
} }
else { else {
@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
query: '', query: '',
category: PluginCategoryEnum.model, category: PluginCategoryEnum.model,
type: 'plugin', type: 'plugin',
pageSize: 1000, page_size: 1000,
exclude, exclude,
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
}) })
} }
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude]) }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => {
const [queryParams, setQueryParams] = useState<PluginsSearchParams>() const [queryParams, setQueryParams] = useState<PluginsSearchParams>()
const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => { const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => {
const pageSize = pluginsSearchParams.pageSize || 40 const page_size = pluginsSearchParams.page_size || 40
return { return {
...pluginsSearchParams, ...pluginsSearchParams,
pageSize, page_size,
} }
}, []) }, [])
@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => {
plugins: [] as Plugin[], plugins: [] as Plugin[],
total: 0, total: 0,
page: 1, page: 1,
pageSize: 40, page_size: 40,
} }
} }
const params = normalizeParams(queryParams) const params = normalizeParams(queryParams)
const { const {
query, query,
sortBy, sort_by,
sortOrder, sort_order,
category, category,
tags, tags,
exclude, exclude,
type, type,
pageSize, page_size,
} = params } = params
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: { body: {
page: pageParam, page: pageParam,
page_size: pageSize, page_size,
query, query,
sort_by: sortBy, sort_by,
sort_order: sortOrder, sort_order,
category: category !== 'all' ? category : '', category: category !== 'all' ? category : '',
tags, tags,
exclude, exclude,
@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)), plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total, total: res.data.total,
page: pageParam, page: pageParam,
pageSize, page_size,
} }
} }
catch { catch {
@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => {
plugins: [], plugins: [],
total: 0, total: 0,
page: pageParam, page: pageParam,
pageSize, page_size,
} }
} }
}, },
getNextPageParam: (lastPage) => { getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1 const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined return loaded < (lastPage.total || 0) ? nextPage : undefined
}, },
initialPageParam: 1, initialPageParam: 1,

View File

@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query' import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server' import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server' import { getQueryClientServer } from '@/context/query-client-server'
import { marketplaceQuery } from '@/service/client'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants' import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceKeys } from './query'
import { marketplaceSearchParamsParsers } from './search-params' import { marketplaceSearchParamsParsers } from './search-params'
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils' import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise<SearchParams>) {
const queryClient = getQueryClientServer() const queryClient = getQueryClientServer()
await queryClient.prefetchQuery({ await queryClient.prefetchQuery({
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)), queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }),
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)), queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
}) })
return dehydrate(queryClient) return dehydrate(queryClient)

View File

@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({
// Mock tanstack query // Mock tanstack query
const mockFetchNextPage = vi.fn() const mockFetchNextPage = vi.fn()
const mockHasNextPage = false const mockHasNextPage = false
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined
let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>) | null = null let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise<unknown>) | null = null let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null
vi.mock('@tanstack/react-query', () => ({ vi.mock('@tanstack/react-query', () => ({
useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise<unknown>, enabled: boolean }) => { useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise<unknown>, enabled: boolean }) => {
@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({
}), }),
useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: { useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: {
queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown> queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>
getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined
enabled: boolean enabled: boolean
}) => { }) => {
// Capture queryFn and getNextPageParam for later testing // Capture queryFn and getNextPageParam for later testing
@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({
// Call getNextPageParam to increase coverage // Call getNextPageParam to increase coverage
if (getNextPageParam) { if (getNextPageParam) {
// Test with more data available // Test with more data available
getNextPageParam({ page: 1, pageSize: 40, total: 100 }) getNextPageParam({ page: 1, page_size: 40, total: 100 })
// Test with no more data // Test with no more data
getNextPageParam({ page: 3, pageSize: 40, total: 100 }) getNextPageParam({ page: 3, page_size: 40, total: 100 })
} }
return { return {
data: mockInfiniteQueryData, data: mockInfiniteQueryData,
@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({
// Mock config // Mock config
vi.mock('@/config', () => ({ vi.mock('@/config', () => ({
API_PREFIX: '/api',
APP_VERSION: '1.0.0', APP_VERSION: '1.0.0',
IS_MARKETPLACE: false, IS_MARKETPLACE: false,
MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1', MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1',
@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => {
expect(() => { expect(() => {
result.current.queryPlugins({ result.current.queryPlugins({
query: 'test', query: 'test',
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
category: 'tool', category: 'tool',
pageSize: 20, page_size: 20,
}) })
}).not.toThrow() }).not.toThrow()
}) })
@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({ result.current.queryPlugins({
query: 'test', query: 'test',
type: 'bundle', type: 'bundle',
pageSize: 40, page_size: 40,
}) })
}).not.toThrow() }).not.toThrow()
}) })
@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({ result.current.queryPlugins({
query: 'test', query: 'test',
category: 'all', category: 'all',
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
}) })
}).not.toThrow() }).not.toThrow()
}) })
@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => {
expect(() => { expect(() => {
result.current.queryPlugins({ result.current.queryPlugins({
query: 'test', query: 'test',
pageSize: 100, page_size: 100,
}) })
}).not.toThrow() }).not.toThrow()
}) })
@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => {
// Set mock data to have pages // Set mock data to have pages
mockInfiniteQueryData = { mockInfiniteQueryData = {
pages: [ pages: [
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 }, { plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 },
], ],
} }
@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => {
it('should expose page and total from infinite query data', async () => { it('should expose page and total from infinite query data', async () => {
mockInfiniteQueryData = { mockInfiniteQueryData = {
pages: [ pages: [
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 }, { plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 }, { plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 },
], ],
} }
@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => {
it('should return total from first page when query is set and data exists', async () => { it('should return total from first page when query is set and data exists', async () => {
mockInfiniteQueryData = { mockInfiniteQueryData = {
pages: [ pages: [
{ plugins: [], total: 50, page: 1, pageSize: 40 }, { plugins: [], total: 50, page: 1, page_size: 40 },
], ],
} }
@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => {
type: 'plugin', type: 'plugin',
query: 'search test', query: 'search test',
category: 'model', category: 'model',
sortBy: 'version_updated_at', sort_by: 'version_updated_at',
sortOrder: 'ASC', sort_order: 'ASC',
}) })
expect(result.current).toBeDefined() expect(result.current).toBeDefined()
@ -1027,13 +1028,13 @@ describe('Advanced Hook Integration', () => {
// Test with all possible parameters // Test with all possible parameters
result.current.queryPlugins({ result.current.queryPlugins({
query: 'comprehensive test', query: 'comprehensive test',
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
category: 'tool', category: 'tool',
tags: ['tag1', 'tag2'], tags: ['tag1', 'tag2'],
exclude: ['excluded-plugin'], exclude: ['excluded-plugin'],
type: 'plugin', type: 'plugin',
pageSize: 50, page_size: 50,
}) })
expect(result.current).toBeDefined() expect(result.current).toBeDefined()
@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({ result.current.queryPlugins({
query: 'direct test', query: 'direct test',
category: 'tool', category: 'tool',
sortBy: 'install_count', sort_by: 'install_count',
sortOrder: 'DESC', sort_order: 'DESC',
pageSize: 40, page_size: 40,
}) })
// Now queryFn should be captured and enabled // Now queryFn should be captured and enabled
@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({ result.current.queryPlugins({
query: 'structure test', query: 'structure test',
pageSize: 20, page_size: 20,
}) })
if (capturedInfiniteQueryFn) { if (capturedInfiniteQueryFn) {
@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => {
plugins: unknown[] plugins: unknown[]
total: number total: number
page: number page: number
pageSize: number page_size: number
} }
// Verify the returned structure // Verify the returned structure
expect(response).toHaveProperty('plugins') expect(response).toHaveProperty('plugins')
expect(response).toHaveProperty('total') expect(response).toHaveProperty('total')
expect(response).toHaveProperty('page') expect(response).toHaveProperty('page')
expect(response).toHaveProperty('pageSize') expect(response).toHaveProperty('page_size')
} }
}) })
}) })
@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => {
], ],
total: 5, total: 5,
page: 1, page: 1,
pageSize: 40, page_size: 40,
}, },
{ {
plugins: [ plugins: [
@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => {
], ],
total: 5, total: 5,
page: 2, page: 2,
pageSize: 40, page_size: 40,
}, },
], ],
} }
@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => {
it('should test hook with pages data for flatMap path', async () => { it('should test hook with pages data for flatMap path', async () => {
mockInfiniteQueryData = { mockInfiniteQueryData = {
pages: [ pages: [
{ plugins: [], total: 100, page: 1, pageSize: 40 }, { plugins: [], total: 100, page: 1, page_size: 40 },
{ plugins: [], total: 100, page: 2, pageSize: 40 }, { plugins: [], total: 100, page: 2, page_size: 40 },
], ],
} }
@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[] plugins: unknown[]
total: number total: number
page: number page: number
pageSize: number page_size: number
} }
// When error is caught, should return fallback data // When error is caught, should return fallback data
expect(response.plugins).toEqual([]) expect(response.plugins).toEqual([])
@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => {
// Test getNextPageParam function directly // Test getNextPageParam function directly
if (capturedGetNextPageParam) { if (capturedGetNextPageParam) {
// When there are more pages // When there are more pages
const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 }) const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 })
expect(nextPage).toBe(2) expect(nextPage).toBe(2)
// When all data is loaded // When all data is loaded
const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 }) const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 })
expect(noMorePages).toBeUndefined() expect(noMorePages).toBeUndefined()
// Edge case: exactly at boundary // Edge case: exactly at boundary
const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 }) const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 })
expect(atBoundary).toBeUndefined() expect(atBoundary).toBeUndefined()
} }
}) })
@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[] plugins: unknown[]
total: number total: number
page: number page: number
pageSize: number page_size: number
} }
// Catch block should return fallback values // Catch block should return fallback values
expect(response.plugins).toEqual([]) expect(response.plugins).toEqual([])
@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => {
plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }], plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }],
total: 10, total: 10,
page: 1, page: 1,
pageSize: 40, page_size: 40,
}, },
], ],
} }
@ -1489,9 +1490,12 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' }, { type: 'plugin', org: 'test', name: 'plugin2' },
] ]
globalThis.fetch = vi.fn().mockResolvedValue({ globalThis.fetch = vi.fn().mockResolvedValue(
json: () => Promise.resolve({ data: { plugins: mockPlugins } }), new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
}) status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplacePluginsByCollectionId } = await import('./utils') const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', { const result = await getMarketplacePluginsByCollectionId('test-collection', {
@ -1514,19 +1518,26 @@ describe('Async Utils', () => {
}) })
it('should pass abort signal when provided', async () => { it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue({ globalThis.fetch = vi.fn().mockResolvedValue(
json: () => Promise.resolve({ data: { plugins: mockPlugins } }), new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
}) status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const controller = new AbortController() const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils') const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalledWith( expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(String), expect.any(Request),
expect.objectContaining({ signal: controller.signal }), expect.any(Object),
) )
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
}) })
}) })
@ -1535,19 +1546,25 @@ describe('Async Utils', () => {
const mockCollections = [ const mockCollections = [
{ name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' }, { name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' },
] ]
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
let callCount = 0 let callCount = 0
globalThis.fetch = vi.fn().mockImplementation(() => { globalThis.fetch = vi.fn().mockImplementation(() => {
callCount++ callCount++
if (callCount === 1) { if (callCount === 1) {
return Promise.resolve({ return Promise.resolve(
json: () => Promise.resolve({ data: { collections: mockCollections } }), new Response(JSON.stringify({ data: { collections: mockCollections } }), {
}) status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
} }
return Promise.resolve({ return Promise.resolve(
json: () => Promise.resolve({ data: { plugins: mockPlugins } }), new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
}) status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
}) })
const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@ -1571,9 +1588,12 @@ describe('Async Utils', () => {
}) })
it('should append condition and type to URL when provided', async () => { it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({ globalThis.fetch = vi.fn().mockResolvedValue(
json: () => Promise.resolve({ data: { collections: [] } }), new Response(JSON.stringify({ data: { collections: [] } }), {
}) status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({ await getMarketplaceCollectionsAndPlugins({
@ -1581,10 +1601,11 @@ describe('Async Utils', () => {
type: 'bundle', type: 'bundle',
}) })
expect(globalThis.fetch).toHaveBeenCalledWith( // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect.stringContaining('condition=category=tool'), expect(globalThis.fetch).toHaveBeenCalled()
expect.any(Object), const call = vi.mocked(globalThis.fetch).mock.calls[0]
) const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
}) })
}) })
}) })

View File

@ -1,22 +1,14 @@
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types' import type { PluginsSearchParams } from './types'
import type { MarketPlaceInputs } from '@/contract/router'
import { useInfiniteQuery, useQuery } from '@tanstack/react-query' import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
import { marketplaceQuery } from '@/service/client'
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils' import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
// TODO: Avoid manual maintenance of query keys and better service management,
// https://github.com/langgenius/dify/issues/30342
export const marketplaceKeys = {
all: ['marketplace'] as const,
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
}
export function useMarketplaceCollectionsAndPlugins( export function useMarketplaceCollectionsAndPlugins(
collectionsParams: CollectionsAndPluginsSearchParams, collectionsParams: MarketPlaceInputs['collections']['query'],
) { ) {
return useQuery({ return useQuery({
queryKey: marketplaceKeys.collections(collectionsParams), queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }),
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }), queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
}) })
} }
@ -25,11 +17,16 @@ export function useMarketplacePlugins(
queryParams: PluginsSearchParams | undefined, queryParams: PluginsSearchParams | undefined,
) { ) {
return useInfiniteQuery({ return useInfiniteQuery({
queryKey: marketplaceKeys.plugins(queryParams), queryKey: marketplaceQuery.searchAdvanced.queryKey({
input: {
body: queryParams!,
params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' },
},
}),
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal), queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
getNextPageParam: (lastPage) => { getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1 const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined return loaded < (lastPage.total || 0) ? nextPage : undefined
}, },
initialPageParam: 1, initialPageParam: 1,

View File

@ -26,8 +26,8 @@ export function useMarketplaceData() {
query: searchPluginText, query: searchPluginText,
category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType, category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType,
tags: filterPluginTags, tags: filterPluginTags,
sortBy: sort.sortBy, sort_by: sort.sortBy,
sortOrder: sort.sortOrder, sort_order: sort.sortOrder,
type: getMarketplaceListFilterType(activePluginType), type: getMarketplaceListFilterType(activePluginType),
} }
}, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort]) }, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort])

View File

@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = {
export type PluginsSearchParams = { export type PluginsSearchParams = {
query: string query: string
page?: number page?: number
pageSize?: number page_size?: number
sortBy?: string sort_by?: string
sortOrder?: string sort_order?: string
category?: string category?: string
tags?: string[] tags?: string[]
exclude?: string[] exclude?: string[]

View File

@ -4,14 +4,12 @@ import type {
MarketplaceCollection, MarketplaceCollection,
PluginsSearchParams, PluginsSearchParams,
} from '@/app/components/plugins/marketplace/types' } from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types' import type { Plugin } from '@/app/components/plugins/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types' import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { import {
APP_VERSION,
IS_MARKETPLACE,
MARKETPLACE_API_PREFIX, MARKETPLACE_API_PREFIX,
} from '@/config' } from '@/config'
import { postMarketplace } from '@/service/base' import { marketplaceClient } from '@/service/client'
import { getMarketplaceUrl } from '@/utils/var' import { getMarketplaceUrl } from '@/utils/var'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants' import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
@ -19,10 +17,6 @@ type MarketplaceFetchOptions = {
signal?: AbortSignal signal?: AbortSignal
} }
const getMarketplaceHeaders = () => new Headers({
'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
})
export const getPluginIconInMarketplace = (plugin: Plugin) => { export const getPluginIconInMarketplace = (plugin: Plugin) => {
if (plugin.type === 'bundle') if (plugin.type === 'bundle')
return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon` return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon`
@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async (
let plugins: Plugin[] = [] let plugins: Plugin[] = []
try { try {
const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins` const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({
const headers = getMarketplaceHeaders() params: {
const marketplaceCollectionPluginsData = await globalThis.fetch( collectionId,
url,
{
cache: 'no-store',
method: 'POST',
headers,
signal: options?.signal,
body: JSON.stringify({
category: query?.category,
exclude: query?.exclude,
type: query?.type,
}),
}, },
) body: query,
const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json() }, {
plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin)) signal: options?.signal,
})
plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin))
} }
// eslint-disable-next-line unused-imports/no-unused-vars // eslint-disable-next-line unused-imports/no-unused-vars
catch (e) { catch (e) {
@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async (
let marketplaceCollections: MarketplaceCollection[] = [] let marketplaceCollections: MarketplaceCollection[] = []
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {} let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
try { try {
let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100` const marketplaceCollectionsDataJson = await marketplaceClient.collections({
if (query?.condition) query: {
marketplaceUrl += `&condition=${query.condition}` ...query,
if (query?.type) page: 1,
marketplaceUrl += `&type=${query.type}` page_size: 100,
const headers = getMarketplaceHeaders()
const marketplaceCollectionsData = await globalThis.fetch(
marketplaceUrl,
{
headers,
cache: 'no-store',
signal: options?.signal,
}, },
) }, {
const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json() signal: options?.signal,
marketplaceCollections = marketplaceCollectionsDataJson.data.collections || [] })
marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || []
await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => { await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {
const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options) const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options)
@ -143,42 +122,42 @@ export const getMarketplacePlugins = async (
plugins: [] as Plugin[], plugins: [] as Plugin[],
total: 0, total: 0,
page: 1, page: 1,
pageSize: 40, page_size: 40,
} }
} }
const { const {
query, query,
sortBy, sort_by,
sortOrder, sort_order,
category, category,
tags, tags,
type, type,
pageSize = 40, page_size = 40,
} = queryParams } = queryParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
try { try {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { const res = await marketplaceClient.searchAdvanced({
params: {
kind: type === 'bundle' ? 'bundles' : 'plugins',
},
body: { body: {
page: pageParam, page: pageParam,
page_size: pageSize, page_size,
query, query,
sort_by: sortBy, sort_by,
sort_order: sortOrder, sort_order,
category: category !== 'all' ? category : '', category: category !== 'all' ? category : '',
tags, tags,
type,
}, },
signal, }, { signal })
})
const resPlugins = res.data.bundles || res.data.plugins || [] const resPlugins = res.data.bundles || res.data.plugins || []
return { return {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)), plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total, total: res.data.total,
page: pageParam, page: pageParam,
pageSize, page_size,
} }
} }
catch { catch {
@ -186,7 +165,7 @@ export const getMarketplacePlugins = async (
plugins: [], plugins: [],
total: 0, total: 0,
page: pageParam, page: pageParam,
pageSize, page_size,
} }
} }
} }

View File

@ -1606,6 +1606,7 @@ export const useNodesInteractions = () => {
const offsetX = currentPosition.x - x const offsetX = currentPosition.x - x
const offsetY = currentPosition.y - y const offsetY = currentPosition.y - y
let idMapping: Record<string, string> = {} let idMapping: Record<string, string> = {}
const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = []
clipboardElements.forEach((nodeToPaste, index) => { clipboardElements.forEach((nodeToPaste, index) => {
const nodeType = nodeToPaste.data.type const nodeType = nodeToPaste.data.type
@ -1619,6 +1620,7 @@ export const useNodesInteractions = () => {
_isBundled: false, _isBundled: false,
_connectedSourceHandleIds: [], _connectedSourceHandleIds: [],
_connectedTargetHandleIds: [], _connectedTargetHandleIds: [],
_dimmed: false,
title: genNewNodeTitleFromOld(nodeToPaste.data.title), title: genNewNodeTitleFromOld(nodeToPaste.data.title),
}, },
position: { position: {
@ -1686,27 +1688,24 @@ export const useNodesInteractions = () => {
return return
// handle paste to nested block // handle paste to nested block
if (selectedNode.data.type === BlockEnum.Iteration) { if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) {
newNode.data.isInIteration = true const isIteration = selectedNode.data.type === BlockEnum.Iteration
newNode.data.iteration_id = selectedNode.data.iteration_id
newNode.parentId = selectedNode.id newNode.data.isInIteration = isIteration
newNode.positionAbsolute = { newNode.data.iteration_id = isIteration ? selectedNode.id : undefined
x: newNode.position.x, newNode.data.isInLoop = !isIteration
y: newNode.position.y, newNode.data.loop_id = !isIteration ? selectedNode.id : undefined
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
}
else if (selectedNode.data.type === BlockEnum.Loop) {
newNode.data.isInLoop = true
newNode.data.loop_id = selectedNode.data.loop_id
newNode.parentId = selectedNode.id newNode.parentId = selectedNode.id
newNode.zIndex = isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
newNode.positionAbsolute = { newNode.positionAbsolute = {
x: newNode.position.x, x: newNode.position.x,
y: newNode.position.y, y: newNode.position.y,
} }
// set position base on parent node // set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode) newNode.position = getNestedNodePosition(newNode, selectedNode)
// update parent children array like native add
parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type })
} }
} }
} }
@ -1737,7 +1736,17 @@ export const useNodesInteractions = () => {
} }
}) })
setNodes([...nodes, ...nodesToPaste]) const newNodes = produce(nodes, (draft: Node[]) => {
parentChildrenToAppend.forEach(({ parentId, childId, childType }) => {
const p = draft.find(n => n.id === parentId)
if (p) {
p.data._children?.push({ nodeId: childId, nodeType: childType })
}
})
draft.push(...nodesToPaste)
})
setNodes(newNodes)
setEdges([...edges, ...edgesToPaste]) setEdges([...edges, ...edgesToPaste])
saveStateToHistory(WorkflowHistoryEvent.NodePaste, { saveStateToHistory(WorkflowHistoryEvent.NodePaste, {
nodeId: nodesToPaste?.[0]?.id, nodeId: nodesToPaste?.[0]?.id,

Some files were not shown because too many files have changed in this diff Show More