mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/hitl-frontend
This commit is contained in:
commit
dfb25df5ec
|
|
@ -35,7 +35,7 @@ from libs.rsa import generate_key_pair
|
|||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
|
|
@ -64,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
|
|||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
|
|
@ -86,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
|
|||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
|
|
@ -102,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
|
|||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
|
|
@ -660,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
|||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip()
|
||||
email = email.strip().lower()
|
||||
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Invalid email address.", fg="red"))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
|
|||
|
||||
class VolcengineTOSStorageConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Volcengine Tinder Object Storage (TOS)
|
||||
Configuration settings for Volcengine Torch Object Storage (TOS)
|
||||
"""
|
||||
|
||||
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(
|
||||
|
|
|
|||
|
|
@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
|
|||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
|
|
@ -100,11 +99,12 @@ class ActivateApi(Resource):
|
|||
def post(self):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
normalized_request_email = args.email.lower() if args.email else None
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args.name
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
|
|
@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
|
|||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
|
|
@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
AccountService.add_email_register_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
|
|||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_email_register_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
|
|
@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
|
|||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email, password) -> Account | None:
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import secrets
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
|
|||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
|
@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
|
|||
def post(self):
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
|
|
@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
|
|||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
token_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
|
|
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
|
|||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
|
|
|||
|
|
@ -90,32 +90,38 @@ class LoginApi(Resource):
|
|||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
request_email = args.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
invitation_data: dict[str, Any] | None = None
|
||||
if args.invite_token:
|
||||
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
if invitation_data is None:
|
||||
invite_token = None
|
||||
|
||||
try:
|
||||
if invitation_data:
|
||||
data = invitation_data.get("data", {})
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args.email:
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
|
|
@ -130,7 +136,7 @@ class LoginApi(Resource):
|
|||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
|
|
@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
|
|
@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
|
|||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
original_email = args.email
|
||||
user_email = original_email.lower()
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args.email:
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
|
|
@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
|
|||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
|
|
@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
|
|||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
|
|||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
account = AccountService.get_user_through_email(email)
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return AccountService.get_user_through_email(email.lower())
|
||||
|
||||
|
||||
def _authenticate_account_with_case_fallback(
|
||||
original_email: str, normalized_email: str, password: str, invite_token: str | None
|
||||
):
|
||||
try:
|
||||
return AccountService.authenticate(original_email, password, invite_token)
|
||||
except services.errors.account.AccountPasswordError:
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import logging
|
|||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
|
@ -118,7 +117,10 @@ class OAuthCallback(Resource):
|
|||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
invitation_email_normalized = (
|
||||
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||
)
|
||||
if invitation_email_normalized != user_info.email.lower():
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
|
@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
|||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
|
|
@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
|||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
normalized_email = user_info.email.lower()
|
||||
oauth_new_user = True
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
"This email account has been deleted within the past "
|
||||
|
|
@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
|||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from typing import Literal, cast
|
|||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
page: int = Field(1, title="Page", description="Page number.")
|
||||
limit: int = Field(20, title="Limit", description="Page size.")
|
||||
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
|
||||
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
|
||||
status: str | None = Field(None, title="Status", description="Document status.")
|
||||
fetch_val: str = Field("false", alias="fetch")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
|
|
@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
|
|||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
page = param.page
|
||||
limit = param.limit
|
||||
search = param.search
|
||||
sort = param.sort_by
|
||||
status = param.status
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
fetch_val = param.fetch_val
|
||||
if isinstance(fetch_val, bool):
|
||||
fetch = fetch_val
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -84,10 +84,11 @@ class SetupApi(Resource):
|
|||
raise NotInitValidateError()
|
||||
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ from fields.member_fields import account_fields
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
|
@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
else:
|
||||
language = "en-US"
|
||||
account = None
|
||||
user_email = args.email
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
|
|
@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email != current_user.email:
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
raise InvalidEmailError()
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
user_email = account.email
|
||||
|
||||
token = AccountService.send_change_email_email(
|
||||
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
|
||||
account=account,
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
|
@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
|
|||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
|
||||
if is_change_email_error_rate_limit:
|
||||
raise EmailChangeLimitError()
|
||||
|
||||
|
|
@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_change_email_error_rate_limit(args.email)
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
|
|||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
|
|
@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
|
|||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
|
||||
if AccountService.is_account_in_freeze(args.new_email):
|
||||
if AccountService.is_account_in_freeze(normalized_new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(args.new_email):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
|
|
@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
|
|||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email != old_email:
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=args.new_email,
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
|
|
@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
|
|||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = CheckEmailUniquePayload.model_validate(payload)
|
||||
if AccountService.is_account_in_freeze(args.email):
|
||||
normalized_email = args.email.lower()
|
||||
if AccountService.is_account_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args.email):
|
||||
if not AccountService.check_email_unique(normalized_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
|
|||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
normalized_invitee_email = invitee_email.lower()
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
encoded_invitee_email = parse.quote(normalized_invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": invitee_email,
|
||||
"email": normalized_invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import secrets
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -22,7 +21,7 @@ from controllers.web import web_ns
|
|||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
|
|
@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
def post(self):
|
||||
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
request_email = payload.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
|
@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
|
||||
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
|
@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
|
|||
def post(self):
|
||||
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = payload.email
|
||||
user_email = payload.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
|
|
@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if payload.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(payload.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
|
|||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
token_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(payload.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
|
|
@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
|
|||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
|
|
|||
|
|
@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args["email"].lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
|
|
@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
else:
|
||||
assistant_message.content = response
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
|
|||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
|
|
@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
|
|
@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
trace_manager=app_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
def _initialize_conversation_variables(self) -> list[VariableUnion]:
|
||||
def _initialize_conversation_variables(self) -> list[Variable]:
|
||||
"""
|
||||
Initialize conversation variables for the current conversation.
|
||||
|
||||
|
|
@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||
|
||||
session.commit()
|
||||
return cast(list[VariableUnion], conversation_variables)
|
||||
return cast(list[Variable], conversation_variables)
|
||||
|
||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ class BaseAppGenerator:
|
|||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
if value and not isinstance(value, dict):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
|
|
@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
|||
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
|
||||
continue
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
logger.warning(
|
||||
"Conversation variable not found in variable pool. selector=%s",
|
||||
selector,
|
||||
|
|
|
|||
|
|
@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
|
|||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_calls:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_calls
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
|
|
@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
|
|||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
|
|
@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
|
|
@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
|
|
@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class SpanBuilder:
|
|||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
kind=span_data.span_kind,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any
|
|||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
|
@ -34,3 +34,4 @@ class SpanData(BaseModel):
|
|||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
|
||||
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from typing import Any, cast
|
|||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
|
|
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
|
|||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant
|
||||
|
|
@ -230,30 +229,32 @@ class WorkflowTool(Tool):
|
|||
"""
|
||||
Resolve user from database (worker/Celery context).
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
|
||||
tenant = session.scalar(tenant_stmt)
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
user_stmt = select(Account).where(Account.id == user_id)
|
||||
user = session.scalar(user_stmt)
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
session.expunge(user)
|
||||
return user
|
||||
|
||||
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
|
||||
end_user = session.scalar(end_user_stmt)
|
||||
if end_user:
|
||||
session.expunge(end_user)
|
||||
return end_user
|
||||
|
||||
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
|
||||
tenant = db.session.scalar(tenant_stmt)
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
user_stmt = select(Account).where(Account.id == user_id)
|
||||
user = db.session.scalar(user_stmt)
|
||||
if user:
|
||||
user.current_tenant = tenant
|
||||
return user
|
||||
|
||||
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
|
||||
end_user = db.session.scalar(end_user_stmt)
|
||||
if end_user:
|
||||
return end_user
|
||||
|
||||
return None
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
if not version:
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
|
|
@ -265,22 +266,24 @@ class WorkflowTool(Tool):
|
|||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
|
||||
return workflow
|
||||
session.expunge(workflow)
|
||||
return workflow
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
stmt = select(App).where(App.id == app_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
session.expunge(app)
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from .variables import (
|
|||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableBase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -62,4 +63,5 @@ __all__ = [
|
|||
"StringSegment",
|
||||
"StringVariable",
|
||||
"Variable",
|
||||
"VariableBase",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
|
|||
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
# - `SegmentGroup`, which is not added to the variable pool.
|
||||
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
|
||||
# - `VariableBase` and its subclasses, which are handled by `Variable`.
|
||||
SegmentUnion: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from .segments import (
|
|||
from .types import SegmentType
|
||||
|
||||
|
||||
class Variable(Segment):
|
||||
class VariableBase(Segment):
|
||||
"""
|
||||
A variable is a segment that has a name.
|
||||
|
||||
|
|
@ -45,23 +45,23 @@ class Variable(Segment):
|
|||
selector: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StringVariable(StringSegment, Variable):
|
||||
class StringVariable(StringSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class FloatVariable(FloatSegment, Variable):
|
||||
class FloatVariable(FloatSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class IntegerVariable(IntegerSegment, Variable):
|
||||
class IntegerVariable(IntegerSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
class ObjectVariable(ObjectSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayVariable(ArraySegment, Variable):
|
||||
class ArrayVariable(ArraySegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
|
|||
return encrypter.obfuscated_token(self.value)
|
||||
|
||||
|
||||
class NoneVariable(NoneSegment, Variable):
|
||||
class NoneVariable(NoneSegment, VariableBase):
|
||||
value_type: SegmentType = SegmentType.NONE
|
||||
value: None = None
|
||||
|
||||
|
||||
class FileVariable(FileSegment, Variable):
|
||||
class FileVariable(FileSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class BooleanVariable(BooleanSegment, Variable):
|
||||
class BooleanVariable(BooleanSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
|
|||
value: Any
|
||||
|
||||
|
||||
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `Variable` for type hinting when serialization is not required.
|
||||
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `VariableBase` for type hinting when serialization is not required.
|
||||
#
|
||||
# Note:
|
||||
# - All variants in `VariableUnion` must inherit from the `Variable` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
VariableUnion: TypeAlias = Annotated[
|
||||
# - All variants in `Variable` must inherit from the `VariableBase` class.
|
||||
# - The union must include all non-abstract subclasses of `VariableBase`.
|
||||
Variable: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneVariable, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringVariable, Tag(SegmentType.STRING)]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
|
|
@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable"):
|
||||
def update(self, conversation_id: str, variable: "VariableBase"):
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||
:param variable: The `Variable` instance containing the updated value.
|
||||
:param variable: The `VariableBase` instance containing the updated value.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
|
|
@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
|
|||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
value: Variable = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing_extensions import TypeIs
|
|||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
|
|
@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, VariableUnion],
|
||||
dict[str, Variable],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
|
|
@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
|
@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
|||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
|||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
|
|
@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
|||
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=selector)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
|
|
@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
|||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
variable: VariableBase,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables import Segment, SegmentGroup, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, ObjectSegment
|
||||
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
|
||||
from core.variables.variables import RAGPipelineVariableInput, Variable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
|
|
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
|
|||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
|
||||
description="Variables mapping",
|
||||
default=defaultdict(dict),
|
||||
)
|
||||
|
|
@ -46,13 +46,13 @@ class VariablePool(BaseModel):
|
|||
description="System variables",
|
||||
default_factory=SystemVariable.empty,
|
||||
)
|
||||
environment_variables: Sequence[VariableUnion] = Field(
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list[VariableUnion],
|
||||
default_factory=list[Variable],
|
||||
)
|
||||
conversation_variables: Sequence[VariableUnion] = Field(
|
||||
conversation_variables: Sequence[Variable] = Field(
|
||||
description="Conversation variables.",
|
||||
default_factory=list[VariableUnion],
|
||||
default_factory=list[Variable],
|
||||
)
|
||||
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
|
||||
description="RAG pipeline variables.",
|
||||
|
|
@ -105,7 +105,7 @@ class VariablePool(BaseModel):
|
|||
f"got {len(selector)} elements"
|
||||
)
|
||||
|
||||
if isinstance(value, Variable):
|
||||
if isinstance(value, VariableBase):
|
||||
variable = value
|
||||
elif isinstance(value, Segment):
|
||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||
|
|
@ -114,9 +114,9 @@ class VariablePool(BaseModel):
|
|||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
self.variable_dictionary[node_id][name] = cast(Variable, variable)
|
||||
|
||||
@classmethod
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import abc
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ class VariableLoader(Protocol):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||
this method should return an empty list.
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class VariableLoader(Protocol):
|
|||
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||
- the first element is the node ID,
|
||||
- the second element is the variable name.
|
||||
:return: a list of Variable objects that match the provided selectors.
|
||||
:return: a list of VariableBase objects that match the provided selectors.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
|
|||
Serves as a placeholder when no variable loading is needed.
|
||||
"""
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
return []
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
|||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@ -136,13 +137,11 @@ class WorkflowEntry:
|
|||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
|
|
@ -158,12 +157,12 @@ class WorkflowEntry:
|
|||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import os
|
|||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,12 +20,17 @@ def is_enabled() -> bool:
|
|||
"""
|
||||
Check if logstore extension is enabled.
|
||||
|
||||
Logstore is considered enabled when:
|
||||
1. All required Aliyun SLS environment variables are set
|
||||
2. At least one repository configuration points to a logstore implementation
|
||||
|
||||
Returns:
|
||||
True if all required Aliyun SLS environment variables are set, False otherwise
|
||||
True if logstore should be initialized, False otherwise
|
||||
"""
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Check if Aliyun SLS connection parameters are configured
|
||||
required_vars = [
|
||||
"ALIYUN_SLS_ACCESS_KEY_ID",
|
||||
"ALIYUN_SLS_ACCESS_KEY_SECRET",
|
||||
|
|
@ -33,24 +39,32 @@ def is_enabled() -> bool:
|
|||
"ALIYUN_SLS_PROJECT_NAME",
|
||||
]
|
||||
|
||||
all_set = all(os.environ.get(var) for var in required_vars)
|
||||
sls_vars_set = all(os.environ.get(var) for var in required_vars)
|
||||
|
||||
if not all_set:
|
||||
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
|
||||
if not sls_vars_set:
|
||||
return False
|
||||
|
||||
return all_set
|
||||
# Check if any repository configuration points to logstore implementation
|
||||
repository_configs = [
|
||||
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
|
||||
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
|
||||
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
|
||||
dify_config.API_WORKFLOW_RUN_REPOSITORY,
|
||||
]
|
||||
|
||||
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
|
||||
|
||||
if not uses_logstore:
|
||||
return False
|
||||
|
||||
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
|
||||
return True
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""
|
||||
Initialize logstore on application startup.
|
||||
|
||||
This function:
|
||||
1. Creates Aliyun SLS project if it doesn't exist
|
||||
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
|
||||
3. Creates indexes with field configurations based on PostgreSQL table structures
|
||||
|
||||
This operation is idempotent and only executes once during application startup.
|
||||
If initialization fails, the application continues running without logstore features.
|
||||
|
||||
Args:
|
||||
app: The Dify application instance
|
||||
|
|
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
|
|||
try:
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
|
||||
logger.info("Initializing logstore...")
|
||||
logger.info("Initializing Aliyun SLS Logstore...")
|
||||
|
||||
# Create logstore client and initialize project/logstores/indexes
|
||||
# Create logstore client and initialize resources
|
||||
logstore_client = AliyunLogStore()
|
||||
logstore_client.init_project_logstore()
|
||||
|
||||
# Attach to app for potential later use
|
||||
app.extensions["logstore"] = logstore_client
|
||||
|
||||
logger.info("Logstore initialized successfully")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize logstore")
|
||||
# Don't raise - allow application to continue even if logstore init fails
|
||||
# This ensures that the application can still run if logstore is misconfigured
|
||||
logger.exception(
|
||||
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
|
||||
"Application will continue but logstore features will NOT work.",
|
||||
os.environ.get("ALIYUN_SLS_ENDPOINT"),
|
||||
os.environ.get("ALIYUN_SLS_REGION"),
|
||||
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
|
||||
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
|
||||
)
|
||||
# Don't raise - allow application to continue even if logstore setup fails
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -179,9 +180,18 @@ class AliyunLogStore:
|
|||
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
|
||||
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
|
||||
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
|
||||
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
|
||||
self.log_enabled: bool = (
|
||||
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
|
||||
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
|
||||
)
|
||||
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
|
||||
|
||||
# Get timeout configuration
|
||||
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
|
||||
|
||||
# Pre-check endpoint connectivity to prevent indefinite hangs
|
||||
self._check_endpoint_connectivity(self.endpoint, check_timeout)
|
||||
|
||||
# Initialize SDK client
|
||||
self.client = LogClient(
|
||||
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
|
||||
|
|
@ -199,6 +209,49 @@ class AliyunLogStore:
|
|||
|
||||
self.__class__._initialized = True
|
||||
|
||||
@staticmethod
|
||||
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
|
||||
"""
|
||||
Check if the SLS endpoint is reachable before creating LogClient.
|
||||
Prevents indefinite hangs when the endpoint is unreachable.
|
||||
|
||||
Args:
|
||||
endpoint: SLS endpoint URL
|
||||
timeout: Connection timeout in seconds
|
||||
|
||||
Raises:
|
||||
ConnectionError: If endpoint is not reachable
|
||||
"""
|
||||
# Parse endpoint URL to extract hostname and port
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
|
||||
hostname = parsed_url.hostname
|
||||
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
|
||||
|
||||
if not hostname:
|
||||
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
|
||||
|
||||
sock = None
|
||||
try:
|
||||
# Create socket and set timeout
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
sock.connect((hostname, port))
|
||||
except Exception as e:
|
||||
# Catch all exceptions and provide clear error message
|
||||
error_type = type(e).__name__
|
||||
raise ConnectionError(
|
||||
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
|
||||
) from e
|
||||
finally:
|
||||
# Ensure socket is properly closed
|
||||
if sock:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception: # noqa: S110
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
@property
|
||||
def supports_pg_protocol(self) -> bool:
|
||||
"""Check if PG protocol is supported and enabled."""
|
||||
|
|
@ -220,19 +273,16 @@ class AliyunLogStore:
|
|||
try:
|
||||
self._use_pg_protocol = self._pg_client.init_connection()
|
||||
if self._use_pg_protocol:
|
||||
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
|
||||
logger.info("Using PG protocol for project %s", self.project_name)
|
||||
# Check if scan_index is enabled for all logstores
|
||||
self._check_and_disable_pg_if_scan_index_disabled()
|
||||
return True
|
||||
else:
|
||||
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
|
||||
logger.info("Using SDK mode for project %s", self.project_name)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
|
||||
self.project_name,
|
||||
str(e),
|
||||
)
|
||||
logger.info("Using SDK mode for project %s", self.project_name)
|
||||
logger.debug("PG connection details: %s", str(e))
|
||||
self._use_pg_protocol = False
|
||||
return False
|
||||
|
||||
|
|
@ -246,10 +296,6 @@ class AliyunLogStore:
|
|||
if self._use_pg_protocol:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Attempting delayed PG connection for newly created project %s ...",
|
||||
self.project_name,
|
||||
)
|
||||
self._attempt_pg_connection_init()
|
||||
self.__class__._pg_connection_timer = None
|
||||
|
||||
|
|
@ -284,11 +330,7 @@ class AliyunLogStore:
|
|||
if project_is_new:
|
||||
# For newly created projects, schedule delayed PG connection
|
||||
self._use_pg_protocol = False
|
||||
logger.info(
|
||||
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
|
||||
self.project_name,
|
||||
self.__class__._pg_connection_delay,
|
||||
)
|
||||
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
|
||||
if self.__class__._pg_connection_timer is not None:
|
||||
self.__class__._pg_connection_timer.cancel()
|
||||
self.__class__._pg_connection_timer = threading.Timer(
|
||||
|
|
@ -299,7 +341,6 @@ class AliyunLogStore:
|
|||
self.__class__._pg_connection_timer.start()
|
||||
else:
|
||||
# For existing projects, attempt PG connection immediately
|
||||
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
|
||||
self._attempt_pg_connection_init()
|
||||
|
||||
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
|
||||
|
|
@ -318,9 +359,9 @@ class AliyunLogStore:
|
|||
existing_config = self.get_existing_index_config(logstore_name)
|
||||
if existing_config and not existing_config.scan_index:
|
||||
logger.info(
|
||||
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
|
||||
"PG protocol requires scan_index to be enabled.",
|
||||
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
|
||||
logstore_name,
|
||||
self.project_name,
|
||||
)
|
||||
self._use_pg_protocol = False
|
||||
# Close PG connection if it was initialized
|
||||
|
|
@ -748,7 +789,6 @@ class AliyunLogStore:
|
|||
reverse=reverse,
|
||||
)
|
||||
|
||||
# Log query info if SQLALCHEMY_ECHO is enabled
|
||||
if self.log_enabled:
|
||||
logger.info(
|
||||
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
|
||||
|
|
@ -770,7 +810,6 @@ class AliyunLogStore:
|
|||
for log in logs:
|
||||
result.append(log.get_contents())
|
||||
|
||||
# Log result count if SQLALCHEMY_ECHO is enabled
|
||||
if self.log_enabled:
|
||||
logger.info(
|
||||
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
|
||||
|
|
@ -845,7 +884,6 @@ class AliyunLogStore:
|
|||
query=full_query,
|
||||
)
|
||||
|
||||
# Log query info if SQLALCHEMY_ECHO is enabled
|
||||
if self.log_enabled:
|
||||
logger.info(
|
||||
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
|
||||
|
|
@ -853,8 +891,7 @@ class AliyunLogStore:
|
|||
self.project_name,
|
||||
from_time,
|
||||
to_time,
|
||||
query,
|
||||
sql,
|
||||
full_query,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -865,7 +902,6 @@ class AliyunLogStore:
|
|||
for log in logs:
|
||||
result.append(log.get_contents())
|
||||
|
||||
# Log result count if SQLALCHEMY_ECHO is enabled
|
||||
if self.log_enabled:
|
||||
logger.info(
|
||||
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ from contextlib import contextmanager
|
|||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.pool
|
||||
from psycopg2 import InterfaceError, OperationalError
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
|
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AliyunLogStorePG:
|
||||
"""
|
||||
PostgreSQL protocol support for Aliyun SLS LogStore.
|
||||
|
||||
Handles PG connection pooling and operations for regions that support PG protocol.
|
||||
"""
|
||||
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
|
||||
|
||||
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
|
||||
"""
|
||||
|
|
@ -36,24 +31,11 @@ class AliyunLogStorePG:
|
|||
self._access_key_secret = access_key_secret
|
||||
self._endpoint = endpoint
|
||||
self.project_name = project_name
|
||||
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
|
||||
self._engine: Any = None # SQLAlchemy Engine
|
||||
self._use_pg_protocol = False
|
||||
|
||||
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
|
||||
"""
|
||||
Check if a TCP port is reachable using socket connection.
|
||||
|
||||
This provides a fast check before attempting full database connection,
|
||||
preventing long waits when connecting to unsupported regions.
|
||||
|
||||
Args:
|
||||
host: Hostname or IP address
|
||||
port: Port number
|
||||
timeout: Connection timeout in seconds (default: 2.0)
|
||||
|
||||
Returns:
|
||||
True if port is reachable, False otherwise
|
||||
"""
|
||||
"""Fast TCP port check to avoid long waits on unsupported regions."""
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(timeout)
|
||||
|
|
@ -65,166 +47,101 @@ class AliyunLogStorePG:
|
|||
return False
|
||||
|
||||
def init_connection(self) -> bool:
|
||||
"""
|
||||
Initialize PostgreSQL connection pool for SLS PG protocol support.
|
||||
|
||||
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
|
||||
_use_pg_protocol to True and creates a connection pool. If connection fails
|
||||
(region doesn't support PG protocol or other errors), returns False.
|
||||
|
||||
Returns:
|
||||
True if PG protocol is supported and initialized, False otherwise
|
||||
"""
|
||||
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
|
||||
try:
|
||||
# Extract hostname from endpoint (remove protocol if present)
|
||||
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
|
||||
|
||||
# Get pool configuration
|
||||
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
|
||||
# Pool configuration
|
||||
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
|
||||
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
|
||||
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
|
||||
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
|
||||
|
||||
logger.debug(
|
||||
"Check PG protocol connection to SLS: host=%s, project=%s",
|
||||
pg_host,
|
||||
self.project_name,
|
||||
)
|
||||
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
|
||||
|
||||
# Fast port connectivity check before attempting full connection
|
||||
# This prevents long waits when connecting to unsupported regions
|
||||
# Fast port check to avoid long waits
|
||||
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
|
||||
logger.info(
|
||||
"USE SDK mode for read/write operations, host=%s",
|
||||
pg_host,
|
||||
)
|
||||
logger.debug("Using SDK mode for host=%s", pg_host)
|
||||
return False
|
||||
|
||||
# Create connection pool
|
||||
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=pg_max_connections,
|
||||
host=pg_host,
|
||||
port=5432,
|
||||
database=self.project_name,
|
||||
user=self._access_key_id,
|
||||
password=self._access_key_secret,
|
||||
sslmode="require",
|
||||
connect_timeout=5,
|
||||
application_name=f"Dify-{dify_config.project.version}",
|
||||
# Build connection URL
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
username = quote_plus(self._access_key_id)
|
||||
password = quote_plus(self._access_key_secret)
|
||||
database_url = (
|
||||
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
|
||||
)
|
||||
|
||||
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
|
||||
# Connection pool creation success already indicates connectivity
|
||||
# Create SQLAlchemy engine with connection pool
|
||||
self._engine = create_engine(
|
||||
database_url,
|
||||
pool_size=pool_size,
|
||||
max_overflow=max_overflow,
|
||||
pool_recycle=pool_recycle,
|
||||
pool_pre_ping=pool_pre_ping,
|
||||
pool_timeout=30,
|
||||
connect_args={
|
||||
"connect_timeout": 5,
|
||||
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
},
|
||||
)
|
||||
|
||||
self._use_pg_protocol = True
|
||||
logger.info(
|
||||
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
|
||||
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
|
||||
self.project_name,
|
||||
pool_size,
|
||||
pool_recycle,
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# PG connection failed - fallback to SDK mode
|
||||
self._use_pg_protocol = False
|
||||
if self._pg_pool:
|
||||
if self._engine:
|
||||
try:
|
||||
self._pg_pool.closeall()
|
||||
self._engine.dispose()
|
||||
except Exception:
|
||||
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
|
||||
self._pg_pool = None
|
||||
logger.debug("Failed to dispose engine during cleanup, ignoring")
|
||||
self._engine = None
|
||||
|
||||
logger.info(
|
||||
"PG protocol connection failed (region may not support PG protocol): %s. "
|
||||
"Falling back to SDK mode for read/write operations.",
|
||||
str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def _is_connection_valid(self, conn: Any) -> bool:
|
||||
"""
|
||||
Check if a connection is still valid.
|
||||
|
||||
Args:
|
||||
conn: psycopg2 connection object
|
||||
|
||||
Returns:
|
||||
True if connection is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check if connection is closed
|
||||
if conn.closed:
|
||||
return False
|
||||
|
||||
# Quick ping test - execute a lightweight query
|
||||
# For SLS PG protocol, we can't use SELECT 1 without FROM,
|
||||
# so we just check the connection status
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.fetchone()
|
||||
return True
|
||||
except Exception:
|
||||
logger.debug("Using SDK mode for region: %s", str(e))
|
||||
return False
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
"""
|
||||
Context manager to get a PostgreSQL connection from the pool.
|
||||
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
|
||||
if not self._engine:
|
||||
raise RuntimeError("SQLAlchemy engine is not initialized")
|
||||
|
||||
Automatically validates and refreshes stale connections.
|
||||
|
||||
Note: Aliyun SLS PG protocol does not support transactions, so we always
|
||||
use autocommit mode.
|
||||
|
||||
Yields:
|
||||
psycopg2 connection object
|
||||
|
||||
Raises:
|
||||
RuntimeError: If PG pool is not initialized
|
||||
"""
|
||||
if not self._pg_pool:
|
||||
raise RuntimeError("PG connection pool is not initialized")
|
||||
|
||||
conn = self._pg_pool.getconn()
|
||||
connection = self._engine.raw_connection()
|
||||
try:
|
||||
# Validate connection and get a fresh one if needed
|
||||
if not self._is_connection_valid(conn):
|
||||
logger.debug("Connection is stale, marking as bad and getting a new one")
|
||||
# Mark connection as bad and get a new one
|
||||
self._pg_pool.putconn(conn, close=True)
|
||||
conn = self._pg_pool.getconn()
|
||||
|
||||
# Aliyun SLS PG protocol does not support transactions, always use autocommit
|
||||
conn.autocommit = True
|
||||
yield conn
|
||||
connection.autocommit = True # SLS PG protocol does not support transactions
|
||||
yield connection
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
# Return connection to pool (or close if it's bad)
|
||||
if self._is_connection_valid(conn):
|
||||
self._pg_pool.putconn(conn)
|
||||
else:
|
||||
self._pg_pool.putconn(conn, close=True)
|
||||
connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the PostgreSQL connection pool."""
|
||||
if self._pg_pool:
|
||||
"""Dispose SQLAlchemy engine and close all connections."""
|
||||
if self._engine:
|
||||
try:
|
||||
self._pg_pool.closeall()
|
||||
logger.info("PG connection pool closed")
|
||||
self._engine.dispose()
|
||||
logger.info("SQLAlchemy engine disposed")
|
||||
except Exception:
|
||||
logger.exception("Failed to close PG connection pool")
|
||||
logger.exception("Failed to dispose engine")
|
||||
|
||||
def _is_retriable_error(self, error: Exception) -> bool:
|
||||
"""
|
||||
Check if an error is retriable (connection-related issues).
|
||||
|
||||
Args:
|
||||
error: Exception to check
|
||||
|
||||
Returns:
|
||||
True if the error is retriable, False otherwise
|
||||
"""
|
||||
# Retry on connection-related errors
|
||||
if isinstance(error, (OperationalError, InterfaceError)):
|
||||
"""Check if error is retriable (connection-related issues)."""
|
||||
# Check for psycopg2 connection errors directly
|
||||
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
|
||||
return True
|
||||
|
||||
# Check error message for specific connection issues
|
||||
error_msg = str(error).lower()
|
||||
retriable_patterns = [
|
||||
"connection",
|
||||
|
|
@ -234,34 +151,18 @@ class AliyunLogStorePG:
|
|||
"reset by peer",
|
||||
"no route to host",
|
||||
"network",
|
||||
"operational error",
|
||||
"interface error",
|
||||
]
|
||||
return any(pattern in error_msg for pattern in retriable_patterns)
|
||||
|
||||
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
|
||||
"""
|
||||
Write log to SLS using PostgreSQL protocol with automatic retry.
|
||||
|
||||
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
|
||||
writes with log_version field for versioning, same as SDK implementation.
|
||||
|
||||
Args:
|
||||
logstore: Name of the logstore table
|
||||
contents: List of (field_name, value) tuples
|
||||
log_enabled: Whether to enable logging
|
||||
|
||||
Raises:
|
||||
psycopg2.Error: If database operation fails after all retries
|
||||
"""
|
||||
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
|
||||
if not contents:
|
||||
return
|
||||
|
||||
# Extract field names and values from contents
|
||||
fields = [field_name for field_name, _ in contents]
|
||||
values = [value for _, value in contents]
|
||||
|
||||
# Build INSERT statement with literal values
|
||||
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
|
||||
# so we need to use mogrify to safely create literal values
|
||||
field_list = ", ".join([f'"{field}"' for field in fields])
|
||||
|
||||
if log_enabled:
|
||||
|
|
@ -272,67 +173,40 @@ class AliyunLogStorePG:
|
|||
len(contents),
|
||||
)
|
||||
|
||||
# Retry configuration
|
||||
max_retries = 3
|
||||
retry_delay = 0.1 # Start with 100ms
|
||||
retry_delay = 0.1
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
# Use mogrify to safely convert values to SQL literals
|
||||
placeholders = ", ".join(["%s"] * len(fields))
|
||||
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
|
||||
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
|
||||
cursor.execute(insert_sql)
|
||||
# Success - exit retry loop
|
||||
return
|
||||
|
||||
except psycopg2.Error as e:
|
||||
# Check if error is retriable
|
||||
if not self._is_retriable_error(e):
|
||||
# Not a retriable error (e.g., data validation error), fail immediately
|
||||
logger.exception(
|
||||
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
|
||||
logstore,
|
||||
)
|
||||
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
|
||||
raise
|
||||
|
||||
# Retriable error - log and retry if we have attempts left
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
|
||||
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
|
||||
logstore,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
str(e),
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2 # Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
# Last attempt failed
|
||||
logger.exception(
|
||||
"Failed to put logs to logstore %s via PG protocol after %d attempts",
|
||||
logstore,
|
||||
max_retries,
|
||||
)
|
||||
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
|
||||
raise
|
||||
|
||||
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Execute SQL query using PostgreSQL protocol with automatic retry.
|
||||
|
||||
Args:
|
||||
sql: SQL query string
|
||||
logstore: Name of the logstore (for logging purposes)
|
||||
log_enabled: Whether to enable logging
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
|
||||
Raises:
|
||||
psycopg2.Error: If database operation fails after all retries
|
||||
"""
|
||||
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
|
||||
if log_enabled:
|
||||
logger.info(
|
||||
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
|
||||
|
|
@ -341,20 +215,16 @@ class AliyunLogStorePG:
|
|||
sql,
|
||||
)
|
||||
|
||||
# Retry configuration
|
||||
max_retries = 3
|
||||
retry_delay = 0.1 # Start with 100ms
|
||||
retry_delay = 0.1
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
|
||||
# Get column names from cursor description
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
|
||||
# Fetch all results and convert to list of dicts
|
||||
result = []
|
||||
for row in cursor.fetchall():
|
||||
row_dict = {}
|
||||
|
|
@ -372,36 +242,31 @@ class AliyunLogStorePG:
|
|||
return result
|
||||
|
||||
except psycopg2.Error as e:
|
||||
# Check if error is retriable
|
||||
if not self._is_retriable_error(e):
|
||||
# Not a retriable error (e.g., SQL syntax error), fail immediately
|
||||
logger.exception(
|
||||
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
|
||||
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
|
||||
logstore,
|
||||
sql,
|
||||
)
|
||||
raise
|
||||
|
||||
# Retriable error - log and retry if we have attempts left
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
|
||||
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
|
||||
logstore,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
str(e),
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2 # Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
# Last attempt failed
|
||||
logger.exception(
|
||||
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
|
||||
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
|
||||
logstore,
|
||||
max_retries,
|
||||
sql,
|
||||
)
|
||||
raise
|
||||
|
||||
# This line should never be reached due to raise above, but makes type checker happy
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
"""
|
||||
LogStore repository utilities.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def safe_float(value: Any, default: float = 0.0) -> float:
|
||||
"""
|
||||
Safely convert a value to float, handling 'null' strings and None.
|
||||
"""
|
||||
if value is None or value in {"null", ""}:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
def safe_int(value: Any, default: int = 0) -> int:
|
||||
"""
|
||||
Safely convert a value to int, handling 'null' strings and None.
|
||||
"""
|
||||
if value is None or value in {"null", ""}:
|
||||
return default
|
||||
try:
|
||||
return int(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
|
@ -14,6 +14,8 @@ from typing import Any
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
|
|
@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
|
|||
model.created_by_role = data.get("created_by_role") or ""
|
||||
model.created_by = data.get("created_by") or ""
|
||||
|
||||
# Numeric fields with defaults
|
||||
model.index = int(data.get("index", 0))
|
||||
model.elapsed_time = float(data.get("elapsed_time", 0))
|
||||
model.index = safe_int(data.get("index", 0))
|
||||
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
|
||||
|
||||
# Optional fields
|
||||
model.workflow_run_id = data.get("workflow_run_id")
|
||||
|
|
@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
node_id,
|
||||
)
|
||||
try:
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_workflow_id = escape_identifier(workflow_id)
|
||||
escaped_node_id = escape_identifier(node_id)
|
||||
|
||||
# Check if PG protocol is supported
|
||||
if self.logstore_client.supports_pg_protocol:
|
||||
# Use PG protocol with SQL query (get latest version of each record)
|
||||
|
|
@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
|
||||
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
|
||||
WHERE tenant_id = '{tenant_id}'
|
||||
AND app_id = '{app_id}'
|
||||
AND workflow_id = '{workflow_id}'
|
||||
AND node_id = '{node_id}'
|
||||
WHERE tenant_id = '{escaped_tenant_id}'
|
||||
AND app_id = '{escaped_app_id}'
|
||||
AND workflow_id = '{escaped_workflow_id}'
|
||||
AND node_id = '{escaped_node_id}'
|
||||
AND __time__ > 0
|
||||
) AS subquery WHERE rn = 1
|
||||
LIMIT 100
|
||||
|
|
@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
else:
|
||||
# Use SDK with LogStore query syntax
|
||||
query = (
|
||||
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
|
||||
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
|
||||
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
|
||||
)
|
||||
from_time = 0
|
||||
to_time = int(time.time()) # now
|
||||
|
|
@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
workflow_run_id,
|
||||
)
|
||||
try:
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_workflow_run_id = escape_identifier(workflow_run_id)
|
||||
|
||||
# Check if PG protocol is supported
|
||||
if self.logstore_client.supports_pg_protocol:
|
||||
# Use PG protocol with SQL query (get latest version of each record)
|
||||
|
|
@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
|
||||
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
|
||||
WHERE tenant_id = '{tenant_id}'
|
||||
AND app_id = '{app_id}'
|
||||
AND workflow_run_id = '{workflow_run_id}'
|
||||
WHERE tenant_id = '{escaped_tenant_id}'
|
||||
AND app_id = '{escaped_app_id}'
|
||||
AND workflow_run_id = '{escaped_workflow_run_id}'
|
||||
AND __time__ > 0
|
||||
) AS subquery WHERE rn = 1
|
||||
LIMIT 1000
|
||||
|
|
@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
)
|
||||
else:
|
||||
# Use SDK with LogStore query syntax
|
||||
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
|
||||
query = (
|
||||
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
|
||||
f"and workflow_run_id: {escaped_workflow_run_id}"
|
||||
)
|
||||
from_time = 0
|
||||
to_time = int(time.time()) # now
|
||||
|
||||
|
|
@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
"""
|
||||
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
|
||||
try:
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_execution_id = escape_identifier(execution_id)
|
||||
|
||||
# Check if PG protocol is supported
|
||||
if self.logstore_client.supports_pg_protocol:
|
||||
# Use PG protocol with SQL query (get latest version of record)
|
||||
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
|
||||
if tenant_id:
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
|
||||
else:
|
||||
tenant_filter = ""
|
||||
|
||||
sql_query = f"""
|
||||
SELECT * FROM (
|
||||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
|
||||
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
|
||||
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
|
||||
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
|
||||
) AS subquery WHERE rn = 1
|
||||
LIMIT 1
|
||||
"""
|
||||
|
|
@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||
)
|
||||
else:
|
||||
# Use SDK with LogStore query syntax
|
||||
# Note: Values must be quoted in LogStore query syntax to prevent injection
|
||||
if tenant_id:
|
||||
query = f"id: {execution_id} and tenant_id: {tenant_id}"
|
||||
query = (
|
||||
f"id:{escape_logstore_query_value(execution_id)} "
|
||||
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
|
||||
)
|
||||
else:
|
||||
query = f"id: {execution_id}"
|
||||
query = f"id:{escape_logstore_query_value(execution_id)}"
|
||||
|
||||
from_time = 0
|
||||
to_time = int(time.time()) # now
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ Key Features:
|
|||
- Optimized deduplication using finished_at IS NOT NULL filter
|
||||
- Window functions only when necessary (running status queries)
|
||||
- Multi-tenant data isolation and security
|
||||
- SQL injection prevention via parameter escaping
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -22,6 +23,8 @@ from typing import Any, cast
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
|
|
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
|
|||
model.created_by_role = data.get("created_by_role") or ""
|
||||
model.created_by = data.get("created_by") or ""
|
||||
|
||||
# Numeric fields with defaults
|
||||
model.total_tokens = int(data.get("total_tokens", 0))
|
||||
model.total_steps = int(data.get("total_steps", 0))
|
||||
model.exceptions_count = int(data.get("exceptions_count", 0))
|
||||
model.total_tokens = safe_int(data.get("total_tokens", 0))
|
||||
model.total_steps = safe_int(data.get("total_steps", 0))
|
||||
model.exceptions_count = safe_int(data.get("exceptions_count", 0))
|
||||
|
||||
# Optional fields
|
||||
model.graph = data.get("graph")
|
||||
|
|
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
|
|||
if model.finished_at and model.created_at:
|
||||
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
|
||||
else:
|
||||
model.elapsed_time = float(data.get("elapsed_time", 0))
|
||||
# Use safe conversion to handle 'null' strings and None values
|
||||
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
|
||||
|
||||
return model
|
||||
|
||||
|
|
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
status,
|
||||
)
|
||||
# Convert triggered_from to list if needed
|
||||
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
|
||||
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
|
||||
triggered_from_list = [triggered_from]
|
||||
else:
|
||||
triggered_from_list = list(triggered_from)
|
||||
|
||||
# Build triggered_from filter
|
||||
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
|
||||
# Build status filter
|
||||
status_filter = f"AND status='{status}'" if status else ""
|
||||
# Build triggered_from filter with escaped values
|
||||
# Support both enum and string values for triggered_from
|
||||
triggered_from_filter = " OR ".join(
|
||||
[
|
||||
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
|
||||
for tf in triggered_from_list
|
||||
]
|
||||
)
|
||||
|
||||
# Build status filter with escaped value
|
||||
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
|
||||
|
||||
# Build last_id filter for pagination
|
||||
# Note: This is simplified. In production, you'd need to track created_at from last record
|
||||
|
|
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
SELECT * FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND ({triggered_from_filter})
|
||||
{status_filter}
|
||||
{last_id_filter}
|
||||
|
|
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
|
||||
|
||||
try:
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_run_id = escape_identifier(run_id)
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
|
||||
# Check if PG protocol is supported
|
||||
if self.logstore_client.supports_pg_protocol:
|
||||
# Use PG protocol with SQL query (get latest version of record)
|
||||
|
|
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
|
||||
FROM "{AliyunLogStore.workflow_execution_logstore}"
|
||||
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
|
||||
WHERE id = '{escaped_run_id}'
|
||||
AND tenant_id = '{escaped_tenant_id}'
|
||||
AND app_id = '{escaped_app_id}'
|
||||
AND __time__ > 0
|
||||
) AS subquery WHERE rn = 1
|
||||
LIMIT 100
|
||||
"""
|
||||
|
|
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
)
|
||||
else:
|
||||
# Use SDK with LogStore query syntax
|
||||
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
|
||||
# Note: Values must be quoted in LogStore query syntax to prevent injection
|
||||
query = (
|
||||
f"id:{escape_logstore_query_value(run_id)} "
|
||||
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
|
||||
f"and app_id:{escape_logstore_query_value(app_id)}"
|
||||
)
|
||||
from_time = 0
|
||||
to_time = int(time.time()) # now
|
||||
|
||||
|
|
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
|
||||
|
||||
try:
|
||||
# Escape parameter to prevent SQL injection
|
||||
escaped_run_id = escape_identifier(run_id)
|
||||
|
||||
# Check if PG protocol is supported
|
||||
if self.logstore_client.supports_pg_protocol:
|
||||
# Use PG protocol with SQL query (get latest version of record)
|
||||
|
|
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
|
||||
FROM "{AliyunLogStore.workflow_execution_logstore}"
|
||||
WHERE id = '{run_id}' AND __time__ > 0
|
||||
WHERE id = '{escaped_run_id}' AND __time__ > 0
|
||||
) AS subquery WHERE rn = 1
|
||||
LIMIT 100
|
||||
"""
|
||||
|
|
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
)
|
||||
else:
|
||||
# Use SDK with LogStore query syntax
|
||||
query = f"id: {run_id}"
|
||||
# Note: Values must be quoted in LogStore query syntax
|
||||
query = f"id:{escape_logstore_query_value(run_id)}"
|
||||
from_time = 0
|
||||
to_time = int(time.time()) # now
|
||||
|
||||
|
|
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
triggered_from,
|
||||
status,
|
||||
)
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter
|
||||
time_filter = ""
|
||||
if time_range:
|
||||
|
|
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
# If status is provided, simple count
|
||||
if status:
|
||||
escaped_status = escape_sql_string(status)
|
||||
|
||||
if status == "running":
|
||||
# Running status requires window function
|
||||
sql = f"""
|
||||
|
|
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND status='running'
|
||||
{time_filter}
|
||||
) t
|
||||
|
|
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
sql = f"""
|
||||
SELECT COUNT(DISTINCT id) as count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
AND status='{status}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND status='{escaped_status}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
"""
|
||||
|
|
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
# No status filter - get counts grouped by status
|
||||
# Use optimized query for finished runs, separate query for running
|
||||
try:
|
||||
# Escape parameters (already escaped above, reuse variables)
|
||||
# Count finished runs grouped by status
|
||||
finished_sql = f"""
|
||||
SELECT status, COUNT(DISTINCT id) as count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY status
|
||||
|
|
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND status='running'
|
||||
{time_filter}
|
||||
) t
|
||||
|
|
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
logger.debug(
|
||||
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
|
|
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
sql = f"""
|
||||
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date
|
||||
|
|
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
app_id,
|
||||
triggered_from,
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
|
|
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
sql = f"""
|
||||
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date
|
||||
|
|
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
app_id,
|
||||
triggered_from,
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
|
|
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
sql = f"""
|
||||
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date
|
||||
|
|
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
app_id,
|
||||
triggered_from,
|
||||
)
|
||||
# Build time range filter
|
||||
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_tenant_id = escape_identifier(tenant_id)
|
||||
escaped_app_id = escape_identifier(app_id)
|
||||
escaped_triggered_from = escape_sql_string(triggered_from)
|
||||
|
||||
# Build time range filter (datetime.isoformat() is safe)
|
||||
time_filter = ""
|
||||
if start_date:
|
||||
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
|
||||
|
|
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
created_by,
|
||||
COUNT(DISTINCT id) AS interactions
|
||||
FROM {AliyunLogStore.workflow_execution_logstore}
|
||||
WHERE tenant_id='{tenant_id}'
|
||||
AND app_id='{app_id}'
|
||||
AND triggered_from='{triggered_from}'
|
||||
WHERE tenant_id='{escaped_tenant_id}'
|
||||
AND app_id='{escaped_app_id}'
|
||||
AND triggered_from='{escaped_triggered_from}'
|
||||
AND finished_at IS NOT NULL
|
||||
{time_filter}
|
||||
GROUP BY date, created_by
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
|
|||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
|
|
@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_serializable(obj):
|
||||
"""
|
||||
Convert non-JSON-serializable objects into JSON-compatible formats.
|
||||
|
||||
- Uses `to_dict()` if it's a callable method.
|
||||
- Falls back to string representation.
|
||||
"""
|
||||
if hasattr(obj, "to_dict") and callable(obj.to_dict):
|
||||
return obj.to_dict()
|
||||
return str(obj)
|
||||
|
||||
|
||||
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
|
||||
# Control flag for dual-write (write to both LogStore and SQL database)
|
||||
# Set to True to enable dual-write for safe migration, False to use LogStore only
|
||||
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
|
||||
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
|
||||
|
||||
# Control flag for whether to write the `graph` field to LogStore.
|
||||
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
|
||||
|
|
@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
# Generate log_version as nanosecond timestamp for record versioning
|
||||
log_version = str(time.time_ns())
|
||||
|
||||
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
|
||||
logstore_model = [
|
||||
("id", domain_model.id_),
|
||||
("log_version", log_version), # Add log_version field for append-only writes
|
||||
|
|
@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
("version", domain_model.workflow_version),
|
||||
(
|
||||
"graph",
|
||||
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
|
||||
json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
|
||||
if domain_model.graph and self._enable_put_graph_field
|
||||
else "{}",
|
||||
),
|
||||
(
|
||||
"inputs",
|
||||
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
|
||||
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
|
||||
if domain_model.inputs
|
||||
else "{}",
|
||||
),
|
||||
(
|
||||
"outputs",
|
||||
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
|
||||
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
|
||||
if domain_model.outputs
|
||||
else "{}",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
|
|||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
Account,
|
||||
|
|
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
|
|||
node_execution_id=data.get("node_execution_id"),
|
||||
workflow_id=data.get("workflow_id", ""),
|
||||
workflow_execution_id=data.get("workflow_run_id"),
|
||||
index=int(data.get("index", 0)),
|
||||
index=safe_int(data.get("index", 0)),
|
||||
predecessor_node_id=data.get("predecessor_node_id"),
|
||||
node_id=data.get("node_id", ""),
|
||||
node_type=NodeType(data.get("node_type", "start")),
|
||||
|
|
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
|
|||
outputs=outputs,
|
||||
status=status,
|
||||
error=data.get("error"),
|
||||
elapsed_time=float(data.get("elapsed_time", 0.0)),
|
||||
elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
|
||||
metadata=domain_metadata,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
|
|
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||
|
||||
# Control flag for dual-write (write to both LogStore and SQL database)
|
||||
# Set to True to enable dual-write for safe migration, False to use LogStore only
|
||||
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
|
||||
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
|
||||
|
||||
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
|
||||
logger.debug(
|
||||
|
|
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||
Save or update the inputs, process_data, or outputs associated with a specific
|
||||
node_execution record.
|
||||
|
||||
For LogStore implementation, this is similar to save() since we always write
|
||||
complete records. We append a new record with updated data fields.
|
||||
For LogStore implementation, this is a no-op for the LogStore write because save()
|
||||
already writes all fields including inputs, process_data, and outputs. The caller
|
||||
typically calls save() first to persist status/metadata, then calls save_execution_data()
|
||||
to persist data fields. Since LogStore writes complete records atomically, we don't
|
||||
need a separate write here to avoid duplicate records.
|
||||
|
||||
However, if dual-write is enabled, we still need to call the SQL repository's
|
||||
save_execution_data() method to properly update the SQL database.
|
||||
|
||||
Args:
|
||||
execution: The NodeExecution instance with data to save
|
||||
"""
|
||||
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
|
||||
# In LogStore, we simply write a new complete record with the data
|
||||
# The log_version timestamp will ensure this is treated as the latest version
|
||||
self.save(execution)
|
||||
logger.debug(
|
||||
"save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
|
||||
execution.id,
|
||||
execution.node_execution_id,
|
||||
)
|
||||
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
|
||||
# Calling save() again would create a duplicate record in the append-only LogStore
|
||||
|
||||
# Dual-write to SQL database if enabled (for safe migration)
|
||||
if self._enable_dual_write:
|
||||
try:
|
||||
self.sql_repository.save_execution_data(execution)
|
||||
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
|
||||
except Exception:
|
||||
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
|
||||
# Don't raise - LogStore write succeeded, SQL is just a backup
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
|
|
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
|
||||
This ensures we only get the final version of each node execution.
|
||||
Uses LogStore SQL query with window function to get the latest version of each node execution.
|
||||
This ensures we only get the most recent version of each node execution record.
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
|
|
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||
A list of NodeExecution instances
|
||||
|
||||
Note:
|
||||
This method filters by finished_at IS NOT NULL to avoid duplicates from
|
||||
version updates. For complete history including intermediate states,
|
||||
a different query strategy would be needed.
|
||||
This method uses ROW_NUMBER() window function partitioned by node_execution_id
|
||||
to get the latest version (highest log_version) of each node execution.
|
||||
"""
|
||||
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
|
||||
# Build SQL query with deduplication using finished_at IS NOT NULL
|
||||
# This optimization avoids window functions for common case where we only
|
||||
# want the final state of each node execution
|
||||
# Build SQL query with deduplication using window function
|
||||
# ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
|
||||
# ensures we get the latest version of each node execution
|
||||
|
||||
# Build ORDER BY clause
|
||||
# Escape parameters to prevent SQL injection
|
||||
escaped_workflow_run_id = escape_identifier(workflow_run_id)
|
||||
escaped_tenant_id = escape_identifier(self._tenant_id)
|
||||
|
||||
# Build ORDER BY clause for outer query
|
||||
order_clause = ""
|
||||
if order_config and order_config.order_by:
|
||||
order_fields = []
|
||||
|
|
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||
if order_fields:
|
||||
order_clause = "ORDER BY " + ", ".join(order_fields)
|
||||
|
||||
sql = f"""
|
||||
SELECT *
|
||||
FROM {AliyunLogStore.workflow_node_execution_logstore}
|
||||
WHERE workflow_run_id='{workflow_run_id}'
|
||||
AND tenant_id='{self._tenant_id}'
|
||||
AND finished_at IS NOT NULL
|
||||
"""
|
||||
|
||||
# Build app_id filter for subquery
|
||||
app_id_filter = ""
|
||||
if self._app_id:
|
||||
sql += f" AND app_id='{self._app_id}'"
|
||||
escaped_app_id = escape_identifier(self._app_id)
|
||||
app_id_filter = f" AND app_id='{escaped_app_id}'"
|
||||
|
||||
# Use window function to get latest version of each node execution
|
||||
sql = f"""
|
||||
SELECT * FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
|
||||
FROM {AliyunLogStore.workflow_node_execution_logstore}
|
||||
WHERE workflow_run_id='{escaped_workflow_run_id}'
|
||||
AND tenant_id='{escaped_tenant_id}'
|
||||
{app_id_filter}
|
||||
) t
|
||||
WHERE rn = 1
|
||||
"""
|
||||
|
||||
if order_clause:
|
||||
sql += f" {order_clause}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
SQL Escape Utility for LogStore Queries
|
||||
|
||||
This module provides escaping utilities to prevent injection attacks in LogStore queries.
|
||||
|
||||
LogStore supports two query modes:
|
||||
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
|
||||
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
|
||||
|
||||
Key Security Concerns:
|
||||
- Prevent tenant A from accessing tenant B's data via injection
|
||||
- SLS queries are read-only, so we focus on data access control
|
||||
- Different escaping strategies for SQL vs LogStore query syntax
|
||||
"""
|
||||
|
||||
|
||||
def escape_sql_string(value: str) -> str:
|
||||
"""
|
||||
Escape a string value for safe use in SQL queries.
|
||||
|
||||
This function escapes single quotes by doubling them, which is the standard
|
||||
SQL escaping method. This prevents SQL injection by ensuring that user input
|
||||
cannot break out of string literals.
|
||||
|
||||
Args:
|
||||
value: The string value to escape
|
||||
|
||||
Returns:
|
||||
Escaped string safe for use in SQL queries
|
||||
|
||||
Examples:
|
||||
>>> escape_sql_string("normal_value")
|
||||
"normal_value"
|
||||
>>> escape_sql_string("value' OR '1'='1")
|
||||
"value'' OR ''1''=''1"
|
||||
>>> escape_sql_string("tenant's_id")
|
||||
"tenant''s_id"
|
||||
|
||||
Security:
|
||||
- Prevents breaking out of string literals
|
||||
- Stops injection attacks like: ' OR '1'='1
|
||||
- Protects against cross-tenant data access
|
||||
"""
|
||||
if not value:
|
||||
return value
|
||||
|
||||
# Escape single quotes by doubling them (standard SQL escaping)
|
||||
# This prevents breaking out of string literals in SQL queries
|
||||
return value.replace("'", "''")
|
||||
|
||||
|
||||
def escape_identifier(value: str) -> str:
|
||||
"""
|
||||
Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
|
||||
|
||||
This function is for PG protocol mode (SQL syntax).
|
||||
For SDK mode, use escape_logstore_query_value() instead.
|
||||
|
||||
Args:
|
||||
value: The identifier value to escape
|
||||
|
||||
Returns:
|
||||
Escaped identifier safe for use in SQL queries
|
||||
|
||||
Examples:
|
||||
>>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
|
||||
"550e8400-e29b-41d4-a716-446655440000"
|
||||
>>> escape_identifier("tenant_id' OR '1'='1")
|
||||
"tenant_id'' OR ''1''=''1"
|
||||
|
||||
Security:
|
||||
- Prevents SQL injection via identifiers
|
||||
- Stops cross-tenant access attempts
|
||||
- Works for UUIDs, alphanumeric IDs, and similar identifiers
|
||||
"""
|
||||
# For identifiers, use the same escaping as strings
|
||||
# This is simple and effective for preventing injection
|
||||
return escape_sql_string(value)
|
||||
|
||||
|
||||
def escape_logstore_query_value(value: str) -> str:
|
||||
"""
|
||||
Escape value for LogStore query syntax (SDK mode).
|
||||
|
||||
LogStore query syntax rules:
|
||||
1. Keywords (and/or/not) are case-insensitive
|
||||
2. Single quotes are ordinary characters (no special meaning)
|
||||
3. Double quotes wrap values: key:"value"
|
||||
4. Backslash is the escape character:
|
||||
- \" for double quote inside value
|
||||
- \\ for backslash itself
|
||||
5. Parentheses can change query structure
|
||||
|
||||
To prevent injection:
|
||||
- Wrap value in double quotes to treat special chars as literals
|
||||
- Escape backslashes and double quotes using backslash
|
||||
|
||||
Args:
|
||||
value: The value to escape for LogStore query syntax
|
||||
|
||||
Returns:
|
||||
Quoted and escaped value safe for LogStore query syntax (includes the quotes)
|
||||
|
||||
Examples:
|
||||
>>> escape_logstore_query_value("normal_value")
|
||||
'"normal_value"'
|
||||
>>> escape_logstore_query_value("value or field:evil")
|
||||
'"value or field:evil"' # 'or' and ':' are now literals
|
||||
>>> escape_logstore_query_value('value"test')
|
||||
'"value\\"test"' # Internal double quote escaped
|
||||
>>> escape_logstore_query_value('value\\test')
|
||||
'"value\\\\test"' # Backslash escaped
|
||||
|
||||
Security:
|
||||
- Prevents injection via and/or/not keywords
|
||||
- Prevents injection via colons (:)
|
||||
- Prevents injection via parentheses
|
||||
- Protects against cross-tenant data access
|
||||
|
||||
Note:
|
||||
Escape order is critical: backslash first, then double quotes.
|
||||
Otherwise, we'd double-escape the escape character itself.
|
||||
"""
|
||||
if not value:
|
||||
return '""'
|
||||
|
||||
# IMPORTANT: Escape backslashes FIRST, then double quotes
|
||||
# This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
|
||||
escaped = value.replace("\\", "\\\\") # \ -> \\
|
||||
escaped = escaped.replace('"', '\\"') # " -> \"
|
||||
|
||||
# Wrap in double quotes to treat as literal string
|
||||
# This prevents and/or/not/:/() from being interpreted as operators
|
||||
return f'"{escaped}"'
|
||||
|
|
@ -38,7 +38,7 @@ from core.variables.variables import (
|
|||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableBase,
|
||||
)
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
|
|
@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
|
||||
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
|
||||
if not mapping.get("variable"):
|
||||
raise VariableError("missing variable")
|
||||
return mapping["variable"]
|
||||
|
||||
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
|
||||
"""
|
||||
This factory function is used to create the environment variable or the conversation variable,
|
||||
not support the File type.
|
||||
|
|
@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
if (value := mapping.get("value")) is None:
|
||||
raise VariableError("missing value")
|
||||
|
||||
result: Variable
|
||||
result: VariableBase
|
||||
match value_type:
|
||||
case SegmentType.STRING:
|
||||
result = StringVariable.model_validate(mapping)
|
||||
|
|
@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
|
||||
if not result.selector:
|
||||
result = result.model_copy(update={"selector": selector})
|
||||
return cast(Variable, result)
|
||||
return cast(VariableBase, result)
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
|
|
@ -285,8 +285,8 @@ def segment_to_variable(
|
|||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str = "",
|
||||
) -> Variable:
|
||||
if isinstance(segment, Variable):
|
||||
) -> VariableBase:
|
||||
if isinstance(segment, VariableBase):
|
||||
return segment
|
||||
name = name or selector[-1]
|
||||
id = id or str(uuid4())
|
||||
|
|
@ -297,7 +297,7 @@ def segment_to_variable(
|
|||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return cast(
|
||||
Variable,
|
||||
VariableBase,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import fields
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, SegmentType, Variable
|
||||
from core.variables import SecretVariable, SegmentType, VariableBase
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
|
|||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, Variable):
|
||||
if isinstance(value, VariableBase):
|
||||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -46,7 +44,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, Segment, SegmentType, Variable
|
||||
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
|
||||
|
|
@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
|
|||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> WorkflowType:
|
||||
def value_of(cls, value: str) -> "WorkflowType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
|
@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
|
|||
raise ValueError(f"invalid workflow type value {value}")
|
||||
|
||||
@classmethod
|
||||
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
|
||||
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
|
||||
"""
|
||||
Get workflow type from app mode.
|
||||
|
||||
|
|
@ -178,12 +176,12 @@ class Workflow(Base): # bug
|
|||
graph: str,
|
||||
features: str,
|
||||
created_by: str,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
rag_pipeline_variables: list[dict],
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Workflow:
|
||||
) -> "Workflow":
|
||||
workflow = Workflow()
|
||||
workflow.id = str(uuid4())
|
||||
workflow.tenant_id = tenant_id
|
||||
|
|
@ -447,7 +445,7 @@ class Workflow(Base): # bug
|
|||
|
||||
# decrypt secret variables value
|
||||
def decrypt_func(
|
||||
var: Variable,
|
||||
var: VariableBase,
|
||||
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
|
|
@ -463,7 +461,7 @@ class Workflow(Base): # bug
|
|||
return decrypted_results
|
||||
|
||||
@environment_variables.setter
|
||||
def environment_variables(self, value: Sequence[Variable]):
|
||||
def environment_variables(self, value: Sequence[VariableBase]):
|
||||
if not value:
|
||||
self._environment_variables = "{}"
|
||||
return
|
||||
|
|
@ -487,7 +485,7 @@ class Workflow(Base): # bug
|
|||
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
|
||||
|
||||
# encrypt secret variables value
|
||||
def encrypt_func(var: Variable) -> Variable:
|
||||
def encrypt_func(var: VariableBase) -> VariableBase:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
else:
|
||||
|
|
@ -517,7 +515,7 @@ class Workflow(Base): # bug
|
|||
return result
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[Variable]:
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
|
@ -527,7 +525,7 @@ class Workflow(Base): # bug
|
|||
return results
|
||||
|
||||
@conversation_variables.setter
|
||||
def conversation_variables(self, value: Sequence[Variable]):
|
||||
def conversation_variables(self, value: Sequence[VariableBase]):
|
||||
self._conversation_variables = json.dumps(
|
||||
{var.name: var.model_dump() for var in value},
|
||||
ensure_ascii=False,
|
||||
|
|
@ -622,7 +620,7 @@ class WorkflowRun(Base):
|
|||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
|
||||
pause: Mapped[WorkflowPause | None] = orm.relationship(
|
||||
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
|
||||
"WorkflowPause",
|
||||
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
|
||||
uselist=False,
|
||||
|
|
@ -692,7 +690,7 @@ class WorkflowRun(Base):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
|
||||
return cls(
|
||||
id=data.get("id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
|
|
@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
||||
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
|
||||
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
|
||||
"WorkflowNodeExecutionOffload",
|
||||
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
|
||||
uselist=True,
|
||||
|
|
@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
|
||||
@staticmethod
|
||||
def preload_offload_data(
|
||||
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
|
||||
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
|
||||
):
|
||||
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
|
||||
|
||||
@staticmethod
|
||||
def preload_offload_data_and_files(
|
||||
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
|
||||
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
|
||||
):
|
||||
return query.options(
|
||||
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
|
||||
|
|
@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
)
|
||||
return extras
|
||||
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
||||
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
||||
|
||||
@property
|
||||
|
|
@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
|
|||
back_populates="offload_data",
|
||||
)
|
||||
|
||||
file: Mapped[UploadFile | None] = orm.relationship(
|
||||
file: Mapped[Optional["UploadFile"]] = orm.relationship(
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
|
|
@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
|||
INSTALLED_APP = "installed-app"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
|
||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
|
@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
|
||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
|
||||
obj = cls(
|
||||
id=variable.id,
|
||||
app_id=app_id,
|
||||
|
|
@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
|
|||
)
|
||||
return obj
|
||||
|
||||
def to_variable(self) -> Variable:
|
||||
def to_variable(self) -> VariableBase:
|
||||
mapping = json.loads(self.data)
|
||||
return variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
|
||||
|
|
@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
|
|||
)
|
||||
|
||||
# Relationship to WorkflowDraftVariableFile
|
||||
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
|
||||
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
|
|
@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
|
|||
node_execution_id: str | None,
|
||||
description: str = "",
|
||||
file_id: str | None = None,
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = WorkflowDraftVariable()
|
||||
variable.id = str(uuid4())
|
||||
variable.created_at = naive_utc_now()
|
||||
|
|
@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
|
|||
name: str,
|
||||
value: Segment,
|
||||
description: str = "",
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||
|
|
@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
|
|||
value: Segment,
|
||||
node_execution_id: str,
|
||||
editable: bool = False,
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
|
|
@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
|
|||
visible: bool = True,
|
||||
editable: bool = True,
|
||||
file_id: str | None = None,
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
|
|
@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
|
|||
)
|
||||
|
||||
# Relationship to UploadFile
|
||||
upload_file: Mapped[UploadFile] = orm.relationship(
|
||||
upload_file: Mapped["UploadFile"] = orm.relationship(
|
||||
foreign_keys=[upload_file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
|
|
@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
|
|||
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
||||
|
||||
# Relationship to WorkflowRun
|
||||
workflow_run: Mapped[WorkflowRun] = orm.relationship(
|
||||
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
|
||||
foreign_keys=[workflow_run_id],
|
||||
# require explicit preloading.
|
||||
lazy="raise",
|
||||
|
|
@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
|
||||
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
||||
if isinstance(pause_reason, HumanInputRequired):
|
||||
return cls(
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.11.2"
|
||||
version = "1.11.3"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from hashlib import sha256
|
|||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
|
@ -748,6 +748,21 @@ class AccountService:
|
|||
cls.email_code_login_rate_limiter.increment_rate_limit(email)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
query_session = session or db.session
|
||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "email_code_login")
|
||||
|
|
@ -1363,16 +1378,22 @@ class RegisterService:
|
|||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
|
||||
normalized_email = email.lower()
|
||||
|
||||
"""Invite new member"""
|
||||
with Session(db.engine) as session:
|
||||
account = session.query(Account).filter_by(email=email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
name = email.split("@")[0]
|
||||
name = normalized_email.split("@")[0]
|
||||
|
||||
account = cls.register(
|
||||
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
|
||||
email=normalized_email,
|
||||
name=name,
|
||||
language=language,
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
# Create new tenant member for invited tenant
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
|
@ -1394,7 +1415,7 @@ class RegisterService:
|
|||
# send email
|
||||
send_invite_member_mail_task.delay(
|
||||
language=language,
|
||||
to=email,
|
||||
to=account.email,
|
||||
token=token,
|
||||
inviter_name=inviter.name if inviter else "Dify",
|
||||
workspace_name=tenant.name,
|
||||
|
|
@ -1493,6 +1514,16 @@ class RegisterService:
|
|||
invitation: dict = json.loads(data)
|
||||
return invitation
|
||||
|
||||
@classmethod
|
||||
def get_invitation_with_case_fallback(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
) -> dict[str, Any] | None:
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
|
||||
if invitation or not email or email == email.lower():
|
||||
return invitation
|
||||
normalized_email = email.lower()
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
|
||||
|
||||
|
||||
def _generate_refresh_token(length: int = 64):
|
||||
token = secrets.token_hex(length)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from core.variables.variables import VariableBase
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ class ConversationVariableUpdater:
|
|||
def __init__(self, session_maker: sessionmaker[Session]) -> None:
|
||||
self._session_maker: sessionmaker[Session] = session_maker
|
||||
|
||||
def update(self, conversation_id: str, variable: Variable) -> None:
|
||||
def update(self, conversation_id: str, variable: VariableBase) -> None:
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRequest:
|
||||
proxies: Mapping[str, str] | None = {
|
||||
|
|
@ -38,6 +43,15 @@ class BaseRequest:
|
|||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
mounts = cls._build_mounts()
|
||||
|
||||
try:
|
||||
# ensure traceparent even when OTEL is disabled
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
except Exception:
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
with httpx.Client(mounts=mounts) as client:
|
||||
response = client.request(method, url, json=json, params=params, headers=headers)
|
||||
return response.json()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
|
|||
from mimetypes import guess_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
|
|||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import ProviderCredential
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import FeatureService, PluginInstallationScope
|
||||
|
|
@ -506,6 +509,33 @@ class PluginService:
|
|||
@staticmethod
|
||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
manager = PluginInstaller()
|
||||
|
||||
# Get plugin info before uninstalling to delete associated credentials
|
||||
try:
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
|
||||
|
||||
if plugin:
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
# Delete provider credentials that match this plugin
|
||||
credentials = db.session.scalars(
|
||||
select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == tenant_id,
|
||||
ProviderCredential.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
).all()
|
||||
|
||||
for cred in credentials:
|
||||
db.session.delete(cred)
|
||||
|
||||
db.session.commit()
|
||||
logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to delete credentials: %s", e)
|
||||
# Continue with uninstall even if credential deletion fails
|
||||
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from core.rag.entities.event import (
|
|||
)
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables.variables import Variable
|
||||
from core.variables.variables import VariableBase
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
|
|
@ -270,8 +270,8 @@ class RagPipelineService:
|
|||
graph: dict,
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
rag_pipeline_variables: list,
|
||||
) -> Workflow:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from libs.passport import PassportService
|
|||
from libs.password import compare_password
|
||||
from models import Account, AccountStatus
|
||||
from models.model import App, EndUser, Site
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
|
|
@ -32,7 +33,7 @@ class WebAppAuthService:
|
|||
@staticmethod
|
||||
def authenticate(email: str, password: str) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
|
|
@ -52,7 +53,7 @@ class WebAppAuthService:
|
|||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
account = db.session.query(Account).where(Account.email == email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
|
|||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.variables import Segment, StringSegment, Variable
|
||||
from core.variables import Segment, StringSegment, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
|
|
@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
|
|||
# Application ID for which variables are being loaded.
|
||||
_app_id: str
|
||||
_tenant_id: str
|
||||
_fallback_variables: Sequence[Variable]
|
||||
_fallback_variables: Sequence[VariableBase]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
fallback_variables: Sequence[Variable] | None = None,
|
||||
fallback_variables: Sequence[VariableBase] | None = None,
|
||||
):
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
|
|
@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
|
|||
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
|
||||
return (selector[0], selector[1])
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
if not selectors:
|
||||
return []
|
||||
|
||||
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
|
||||
variable_by_selector: dict[tuple[str, str], Variable] = {}
|
||||
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
|
||||
variable_by_selector: dict[tuple[str, str], VariableBase] = {}
|
||||
|
||||
with Session(bind=self._engine, expire_on_commit=False) as session:
|
||||
srv = WorkflowDraftVariableService(session)
|
||||
|
|
@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
|
|||
|
||||
return list(variable_by_selector.values())
|
||||
|
||||
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
|
||||
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
|
||||
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
|
||||
# and must remain synchronized with it.
|
||||
# Ideally, these should be co-located for better maintainability.
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
|||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables import VariableBase
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
|
|
@ -198,8 +198,8 @@ class WorkflowService:
|
|||
features: dict,
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
|
|
@ -1044,7 +1044,7 @@ def _setup_variable_pool(
|
|||
workflow: Workflow,
|
||||
node_type: NodeType,
|
||||
conversation_id: str,
|
||||
conversation_variables: list[Variable],
|
||||
conversation_variables: list[VariableBase],
|
||||
):
|
||||
# Only inject system variables for START node type.
|
||||
if node_type == NodeType.START or node_type.is_trigger_node:
|
||||
|
|
@ -1070,9 +1070,9 @@ def _setup_variable_pool(
|
|||
system_variables=system_variable,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables), #
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
conversation_variables=cast(list[Variable], conversation_variables), #
|
||||
)
|
||||
|
||||
return variable_pool
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class TestActivateCheckApi:
|
|||
"tenant": tenant,
|
||||
}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking valid invitation token.
|
||||
|
|
@ -66,7 +66,7 @@ class TestActivateCheckApi:
|
|||
assert response["data"]["workspace_id"] == "workspace-123"
|
||||
assert response["data"]["email"] == "invitee@example.com"
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test checking invalid invitation token.
|
||||
|
|
@ -88,7 +88,7 @@ class TestActivateCheckApi:
|
|||
# Assert
|
||||
assert response["is_valid"] is False
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without workspace ID.
|
||||
|
|
@ -109,7 +109,7 @@ class TestActivateCheckApi:
|
|||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without email parameter.
|
||||
|
|
@ -130,6 +130,20 @@ class TestActivateCheckApi:
|
|||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
|
||||
"""Ensure token validation uses lowercase emails."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
with app.test_request_context(
|
||||
"/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token"
|
||||
):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
|
||||
|
||||
class TestActivateApi:
|
||||
"""Test cases for account activation endpoint."""
|
||||
|
|
@ -212,7 +226,7 @@ class TestActivateApi:
|
|||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test account activation with invalid token.
|
||||
|
|
@ -241,7 +255,7 @@ class TestActivateApi:
|
|||
with pytest.raises(AlreadyActivateError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_sets_interface_theme(
|
||||
|
|
@ -290,7 +304,7 @@ class TestActivateApi:
|
|||
("es-ES", "Europe/Madrid"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_with_different_locales(
|
||||
|
|
@ -336,7 +350,7 @@ class TestActivateApi:
|
|||
assert mock_account.interface_language == language
|
||||
assert mock_account.timezone == timezone
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_returns_success_response(
|
||||
|
|
@ -376,7 +390,7 @@ class TestActivateApi:
|
|||
# Assert
|
||||
assert response == {"result": "success"}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_without_workspace_id(
|
||||
|
|
@ -415,3 +429,37 @@ class TestActivateApi:
|
|||
# Assert
|
||||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_normalizes_email_before_lookup(
|
||||
self,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
):
|
||||
"""Ensure uppercase emails are normalized before lookup and revocation."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "Invitee@Example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
assert response["result"] == "success"
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
|
|||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
|
|
@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
|
|||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
|
|
@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
|
|||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,177 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
EmailRegisterResetApi,
|
||||
EmailRegisterSendEmailApi,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestEmailRegisterSendEmailApi:
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
|
||||
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_email_normalizes_and_falls_back(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_email_send_ip_limit,
|
||||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
mock_account = MagicMock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/send-email",
|
||||
method="POST",
|
||||
json={"email": "Invitee@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailRegisterSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_is_freeze.assert_called_once_with("invitee@example.com")
|
||||
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestEmailRegisterCheckApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
|
||||
def test_validity_normalizes_email_before_checks(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = EmailRegisterCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("user@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="4321", additional_data={"phase": "register"}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestEmailRegisterResetApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_creates_account_with_normalized_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_normalizes_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
controller_features = SimpleNamespace(is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch(
|
||||
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
|
||||
return_value=controller_features,
|
||||
),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=mock_account,
|
||||
email="user@example.com",
|
||||
language="zh-Hans",
|
||||
is_allow_register=True,
|
||||
)
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_check_normalizes_email(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("admin@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"Admin@Example.com",
|
||||
code="4321",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("admin@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_fetches_account_with_original_email(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("token-123")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -76,7 +76,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
|
|
@ -120,7 +120,7 @@ class TestLoginApi:
|
|||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert response.json["result"] == "success"
|
||||
|
|
@ -128,7 +128,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
|
|
@ -182,7 +182,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login rejection when rate limit is exceeded.
|
||||
|
|
@ -230,7 +230,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
def test_login_fails_with_invalid_credentials(
|
||||
|
|
@ -269,7 +269,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
def test_login_fails_for_banned_account(
|
||||
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
|
||||
|
|
@ -298,7 +298,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
|
|
@ -343,7 +343,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login failure when invitation email doesn't match login email.
|
||||
|
|
@ -371,6 +371,52 @@ class TestLoginApi:
|
|||
with pytest.raises(InvalidEmailError):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_login_retries_with_lowercase_email(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login_service,
|
||||
mock_get_tenants,
|
||||
mock_add_rate_limit,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""Test that login retries with lowercase email when uppercase lookup fails."""
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
mock_login_service.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.json["result"] == "success"
|
||||
assert mock_authenticate.call_args_list == [
|
||||
(("Upper@Example.com", "ValidPass123!", None), {}),
|
||||
(("upper@example.com", "ValidPass123!", None), {}),
|
||||
]
|
||||
mock_add_rate_limit.assert_not_called()
|
||||
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
"""Test cases for the LogoutApi endpoint."""
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from controllers.console.auth.oauth import (
|
|||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
|
|
@ -215,6 +216,34 @@ class TestOAuthCallback:
|
|||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_invitation_comparison_is_case_insensitive(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_register_service,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
|
||||
id="123", name="Test User", email="User@Example.com"
|
||||
)
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_register_service.is_valid_invite_token.return_value = True
|
||||
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||
resource.get("github")
|
||||
|
||||
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
|
|
@ -395,12 +424,12 @@ class TestAccountGeneration:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.oauth.Session")
|
||||
@patch("controllers.console.auth.oauth.select")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
|
||||
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for Session creation
|
||||
mock_db.engine = MagicMock()
|
||||
|
|
@ -410,15 +439,31 @@ class TestAccountGeneration:
|
|||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||
mock_get_account.assert_not_called()
|
||||
|
||||
# Test fallback to email
|
||||
# Test fallback to email lookup
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
|
||||
mock_session = MagicMock()
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert result == expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
|
|
@ -466,6 +511,35 @@ class TestAccountGeneration:
|
|||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
):
|
||||
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US"}):
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
|
|
|
|||
|
|
@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import (
|
|||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
|
|
@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi:
|
|||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_success(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
|
|
@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
|
|
@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi:
|
|||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_language_handling(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
|
|
@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
|
|
@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi:
|
|||
assert response["email"] == "test@example.com"
|
||||
assert response["token"] == "new_token"
|
||||
mock_revoke_token.assert_called_once_with("old_token")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"test@example.com", code="123456", additional_data={"phase": "reset"}
|
||||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
def test_verify_code_preserves_token_email_case(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_generate_token,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "999888", "token": "upper_token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"User@Example.com", code="999888", additional_data={"phase": "reset"}
|
||||
)
|
||||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
|
|
@ -355,20 +395,16 @@ class TestForgotPasswordResetApi:
|
|||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||
def test_reset_password_success(
|
||||
self,
|
||||
mock_get_tenants,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
|
|
@ -383,11 +419,8 @@ class TestForgotPasswordResetApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
||||
# Act
|
||||
|
|
@ -475,13 +508,11 @@ class TestForgotPasswordResetApi:
|
|||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
|
@ -491,11 +522,8 @@ class TestForgotPasswordResetApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.setup import SetupApi
|
||||
|
||||
|
||||
class TestSetupApi:
|
||||
def test_post_lowercases_email_before_register(self):
|
||||
"""Ensure setup registration normalizes email casing."""
|
||||
payload = {
|
||||
"email": "Admin@Example.com",
|
||||
"name": "Admin User",
|
||||
"password": "ValidPass123!",
|
||||
"language": "en-US",
|
||||
}
|
||||
setup_api = SetupApi(api=None)
|
||||
|
||||
mock_console_ns = SimpleNamespace(payload=payload)
|
||||
|
||||
with (
|
||||
patch("controllers.console.setup.console_ns", mock_console_ns),
|
||||
patch("controllers.console.setup.get_setup_status", return_value=False),
|
||||
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
|
||||
patch("controllers.console.setup.get_init_validate_status", return_value=True),
|
||||
patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"),
|
||||
patch("controllers.console.setup.request", object()),
|
||||
patch("controllers.console.setup.RegisterService.setup") as mock_register,
|
||||
):
|
||||
response, status = setup_api.post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
assert status == 201
|
||||
mock_register.assert_called_once_with(
|
||||
email="admin@example.com",
|
||||
name=payload["name"],
|
||||
password=payload["password"],
|
||||
ip_address="127.0.0.1",
|
||||
language=payload["language"],
|
||||
)
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
ChangeEmailResetApi,
|
||||
ChangeEmailSendEmailApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
return app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
account = Account(name=account_id, email=email)
|
||||
account.email = email
|
||||
account.id = account_id
|
||||
account.status = "active"
|
||||
account._current_tenant = tenant_obj
|
||||
return account
|
||||
|
||||
|
||||
def _set_logged_in_user(account: Account):
|
||||
g._login_user = account
|
||||
g._current_tenant = account.current_tenant
|
||||
|
||||
|
||||
class TestChangeEmailSend:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-abc"}
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=None,
|
||||
email="new@example.com",
|
||||
old_email="current@example.com",
|
||||
language="en-US",
|
||||
phase="new_email",
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_validate_with_normalized_email(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_before_update(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "New@Example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_is_freeze.assert_called_once_with("new@example.com")
|
||||
mock_check_unique.assert_called_once_with("new@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "feedback": "test"},
|
||||
):
|
||||
response = AccountDeleteUpdateFeedbackApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_update.assert_called_once_with("User@Example.com", "test")
|
||||
|
||||
|
||||
class TestCheckEmailUnique:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/check-email-unique",
|
||||
method="POST",
|
||||
json={"email": "Case@Test.com"},
|
||||
):
|
||||
response = CheckEmailUnique().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_is_freeze.assert_called_once_with("case@test.com")
|
||||
mock_check_unique.assert_called_once_with("case@test.com")
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
session = MagicMock()
|
||||
first = MagicMock()
|
||||
first.scalar_one_or_none.return_value = None
|
||||
second = MagicMock()
|
||||
expected_account = MagicMock()
|
||||
second.scalar_one_or_none.return_value = expected_account
|
||||
session.execute.side_effect = [first, second]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
|
||||
|
||||
assert result is expected_account
|
||||
assert session.execute.call_count == 2
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.workspace.members import MemberInviteEmailApi
|
||||
from models.account import Account, TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
return flask_app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_feature_flags():
|
||||
placeholder_quota = SimpleNamespace(limit=0, size=0)
|
||||
workspace_members = SimpleNamespace(is_available=lambda count: True)
|
||||
return SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
workspace_members=workspace_members,
|
||||
members=placeholder_quota,
|
||||
apps=placeholder_quota,
|
||||
vector_space=placeholder_quota,
|
||||
documents_upload_quota=placeholder_quota,
|
||||
annotation_quota_limit=placeholder_quota,
|
||||
)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
@patch("controllers.console.workspace.members.FeatureService.get_features")
|
||||
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
|
||||
@patch("controllers.console.workspace.members.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
def test_invite_normalizes_emails(
|
||||
self,
|
||||
mock_csrf,
|
||||
mock_db,
|
||||
mock_current_account,
|
||||
mock_invite_member,
|
||||
mock_get_features,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
|
||||
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
|
||||
mock_current_account.return_value = (inviter, tenant.id)
|
||||
|
||||
with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"):
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/members/invite-email",
|
||||
method="POST",
|
||||
json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"},
|
||||
):
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account._current_tenant = tenant
|
||||
g._login_user = account
|
||||
g._current_tenant = tenant
|
||||
response, status_code = MemberInviteEmailApi().post()
|
||||
|
||||
assert status_code == 201
|
||||
assert response["invitation_results"][0]["email"] == "user@example.com"
|
||||
|
||||
assert mock_invite_member.call_count == 1
|
||||
call_args = mock_invite_member.call_args
|
||||
assert call_args.kwargs["tenant"] == tenant
|
||||
assert call_args.kwargs["email"] == "User@Example.com"
|
||||
assert call_args.kwargs["language"] == "en-US"
|
||||
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
|
||||
assert call_args.kwargs["inviter"] == inviter
|
||||
mock_csrf.assert_called_once()
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
"""Unit tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.forgot_password using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.forgot_password"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
forgot_password_module = _load_controller_module()
|
||||
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
|
||||
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
|
||||
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_web_endpoint_guards():
|
||||
"""Stub enterprise and feature toggles used by route decorators."""
|
||||
|
||||
features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_controller_db():
|
||||
"""Replace controller-level db reference with a simple stub."""
|
||||
|
||||
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
|
||||
fake_wraps_db = SimpleNamespace(
|
||||
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
|
||||
)
|
||||
with (
|
||||
patch("controllers.web.forgot_password.db", fake_db),
|
||||
patch("controllers.console.wraps.db", fake_wraps_db),
|
||||
):
|
||||
yield fake_db
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
|
||||
def test_send_reset_email_success(
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password returns token when email exists and limits allow."""
|
||||
|
||||
mock_account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "user@example.com"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "reset-token"}
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
|
||||
def test_check_token_success(
|
||||
mock_is_rate_limited: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke: MagicMock,
|
||||
mock_generate: MagicMock,
|
||||
mock_reset_limit: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/validity validates the code and refreshes token."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limited.assert_called_once_with("user@example.com")
|
||||
mock_get_data.assert_called_once_with("old-token")
|
||||
mock_revoke.assert_called_once_with("old-token")
|
||||
mock_generate.assert_called_once_with(
|
||||
"user@example.com",
|
||||
code="123456",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_success(
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_token_bytes: MagicMock,
|
||||
mock_hash_password: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/resets updates the stored password when token is valid."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
|
||||
account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
session_ctx.commit.assert_called_once()
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config", dify_settings),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
def test_should_normalize_email_before_sending(
|
||||
self,
|
||||
mock_session_cls,
|
||||
mock_extract_ip,
|
||||
mock_rate_limit,
|
||||
mock_get_account,
|
||||
mock_send_mail,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_should_normalize_email_for_validity_checks(
|
||||
self,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"User@Example.com",
|
||||
code="1234",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_should_preserve_token_email_case(
|
||||
self,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"MixedCase@Example.com",
|
||||
code="5678",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_revoke_token.assert_called_once_with("token-upper")
|
||||
mock_reset_rate.assert_called_once_with("mixedcase@example.com")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_should_fetch_account_with_fallback(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_should_update_password_and_commit(
|
||||
self,
|
||||
mock_get_account,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_token_bytes,
|
||||
mock_hash_password,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||
account = MagicMock()
|
||||
mock_get_account.return_value = account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
mock_session.commit.assert_called_once()
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
|
||||
web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config", console_dify),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
patch("controllers.web.login.dify_config", web_dify),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendEmailApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.send_email_code_login_email")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
|
||||
def test_should_fetch_account_with_original_email(
|
||||
self,
|
||||
mock_get_user,
|
||||
mock_send_email,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailCodeLoginSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
|
||||
|
||||
|
||||
class TestEmailCodeLoginApi:
|
||||
@patch("controllers.web.login.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
|
||||
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data")
|
||||
def test_should_normalize_email_before_validating(
|
||||
self,
|
||||
mock_get_token_data,
|
||||
mock_revoke_token,
|
||||
mock_get_user,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_user.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
response = EmailCodeLoginApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
|
|
@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M
|
|||
def scalar(self, _stmt):
|
||||
return self.results.pop(0)
|
||||
|
||||
# SQLAlchemy Session APIs used by code under test
|
||||
def expunge(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
# support `with session_factory.create_session() as session:`
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
tenant = SimpleNamespace(id="tenant_id")
|
||||
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
|
||||
db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))
|
||||
|
||||
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
|
||||
# Monkeypatch session factory to return our stub session
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.session_factory.create_session",
|
||||
lambda: StubSession([tenant, None, end_user]),
|
||||
)
|
||||
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
|
|
@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt
|
|||
def scalar(self, _stmt):
|
||||
return self.results.pop(0)
|
||||
|
||||
db_stub = SimpleNamespace(session=StubSession([None]))
|
||||
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
|
||||
def expunge(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
|
||||
# Monkeypatch session factory to return our stub session with no tenant
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.session_factory.create_session",
|
||||
lambda: StubSession([None]),
|
||||
)
|
||||
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ from core.variables.variables import (
|
|||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@ -96,7 +95,7 @@ class _Segments(BaseModel):
|
|||
|
||||
|
||||
class _Variables(BaseModel):
|
||||
variables: list[VariableUnion]
|
||||
variables: list[Variable]
|
||||
|
||||
|
||||
def create_test_file(
|
||||
|
|
@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
|
|||
# Create one instance of each variable type
|
||||
test_file = create_test_file()
|
||||
|
||||
all_variables: list[VariableUnion] = [
|
||||
all_variables: list[Variable] = [
|
||||
NoneVariable(name="none_var"),
|
||||
StringVariable(value="test string", name="string_var"),
|
||||
IntegerVariable(value=42, name="int_var"),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from core.variables import (
|
|||
SegmentType,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.variables import Variable
|
||||
from core.variables.variables import VariableBase
|
||||
|
||||
|
||||
def test_frozen_variables():
|
||||
|
|
@ -76,7 +76,7 @@ def test_object_variable_to_object():
|
|||
|
||||
|
||||
def test_variable_to_object():
|
||||
var: Variable = StringVariable(name="text", value="text")
|
||||
var: VariableBase = StringVariable(name="text", value="text")
|
||||
assert var.to_object() == "text"
|
||||
var = IntegerVariable(name="integer", value=42)
|
||||
assert var.to_object() == 42
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from core.variables.variables import (
|
|||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
StringVariable,
|
||||
VariableUnion,
|
||||
Variable,
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
|
@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
|
|||
)
|
||||
|
||||
# Create environment variables with all types including ArrayFileVariable
|
||||
env_vars: list[VariableUnion] = [
|
||||
env_vars: list[Variable] = [
|
||||
StringVariable(
|
||||
id="env_string_id",
|
||||
name="env_string",
|
||||
|
|
@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
|
|||
]
|
||||
|
||||
# Create conversation variables with complex data
|
||||
conv_vars: list[VariableUnion] = [
|
||||
conv_vars: list[Variable] = [
|
||||
StringVariable(
|
||||
id="conv_string_id",
|
||||
name="conv_string",
|
||||
|
|
|
|||
|
|
@ -2,13 +2,17 @@ from types import SimpleNamespace
|
|||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.enums import FileType
|
||||
from core.file.models import File, FileTransferMethod
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
|
@ -96,6 +100,58 @@ class TestWorkflowEntry:
|
|||
assert output_var is not None
|
||||
assert output_var.value == "system_user"
|
||||
|
||||
def test_single_step_run_injects_code_limits(self):
|
||||
"""Ensure single-step CodeNode execution configures limits."""
|
||||
# Arrange
|
||||
node_id = "code_node"
|
||||
node_data = {
|
||||
"type": "code",
|
||||
"title": "Code",
|
||||
"desc": None,
|
||||
"variables": [],
|
||||
"code_language": CodeLanguage.PYTHON3,
|
||||
"code": "def main():\n return {}",
|
||||
"outputs": {},
|
||||
}
|
||||
node_config = {"id": node_id, "data": node_data}
|
||||
|
||||
class StubWorkflow:
|
||||
def __init__(self):
|
||||
self.tenant_id = "tenant"
|
||||
self.app_id = "app"
|
||||
self.id = "workflow"
|
||||
self.graph_dict = {"nodes": [node_config], "edges": []}
|
||||
|
||||
def get_node_config_by_id(self, target_id: str):
|
||||
assert target_id == node_id
|
||||
return node_config
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
|
||||
expected_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
|
||||
# Act
|
||||
node, _ = WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_id="user",
|
||||
user_inputs={},
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(node, CodeNode)
|
||||
assert node._limits == expected_limits
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
|
||||
"""Test mapping environment variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with environment variables
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""LogStore extension unit tests."""
|
||||
|
|
@ -0,0 +1,469 @@
|
|||
"""
|
||||
Unit tests for SQL escape utility functions.
|
||||
|
||||
These tests ensure that SQL injection attacks are properly prevented
|
||||
in LogStore queries, particularly for cross-tenant access scenarios.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
|
||||
|
||||
|
||||
class TestEscapeSQLString:
|
||||
"""Test escape_sql_string function."""
|
||||
|
||||
def test_escape_empty_string(self):
|
||||
"""Test escaping empty string."""
|
||||
assert escape_sql_string("") == ""
|
||||
|
||||
def test_escape_normal_string(self):
|
||||
"""Test escaping string without special characters."""
|
||||
assert escape_sql_string("tenant_abc123") == "tenant_abc123"
|
||||
assert escape_sql_string("app-uuid-1234") == "app-uuid-1234"
|
||||
|
||||
def test_escape_single_quote(self):
|
||||
"""Test escaping single quote."""
|
||||
# Single quote should be doubled
|
||||
assert escape_sql_string("tenant'id") == "tenant''id"
|
||||
assert escape_sql_string("O'Reilly") == "O''Reilly"
|
||||
|
||||
def test_escape_multiple_quotes(self):
|
||||
"""Test escaping multiple single quotes."""
|
||||
assert escape_sql_string("a'b'c") == "a''b''c"
|
||||
assert escape_sql_string("'''") == "''''''"
|
||||
|
||||
# === SQL Injection Attack Scenarios ===
|
||||
|
||||
def test_prevent_boolean_injection(self):
|
||||
"""Test prevention of boolean injection attacks."""
|
||||
# Classic OR 1=1 attack
|
||||
malicious_input = "tenant' OR '1'='1"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "tenant'' OR ''1''=''1"
|
||||
|
||||
# When used in SQL, this becomes a safe string literal
|
||||
sql = f"WHERE tenant_id='{escaped}'"
|
||||
assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'"
|
||||
# The entire input is now a string literal that won't match any tenant
|
||||
|
||||
def test_prevent_or_injection(self):
|
||||
"""Test prevention of OR-based injection."""
|
||||
malicious_input = "tenant_a' OR tenant_id='tenant_b"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "tenant_a'' OR tenant_id=''tenant_b"
|
||||
|
||||
sql = f"WHERE tenant_id='{escaped}'"
|
||||
# The OR is now part of the string literal, not SQL logic
|
||||
assert "OR tenant_id=" in sql
|
||||
# The SQL has: opening ', doubled internal quotes '', and closing '
|
||||
assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'"
|
||||
|
||||
def test_prevent_union_injection(self):
|
||||
"""Test prevention of UNION-based injection."""
|
||||
malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1"
|
||||
|
||||
# UNION becomes part of the string literal
|
||||
assert "UNION" in escaped
|
||||
assert escaped.count("''") == 4 # All internal quotes are doubled
|
||||
|
||||
def test_prevent_comment_injection(self):
|
||||
"""Test prevention of comment-based injection."""
|
||||
# SQL comment to bypass remaining conditions
|
||||
malicious_input = "tenant' --"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "tenant'' --"
|
||||
|
||||
sql = f"WHERE tenant_id='{escaped}' AND deleted=false"
|
||||
# The -- is now inside the string, not a SQL comment
|
||||
assert "--" in sql
|
||||
assert "AND deleted=false" in sql # This part is NOT commented out
|
||||
|
||||
def test_prevent_semicolon_injection(self):
|
||||
"""Test prevention of semicolon-based multi-statement injection."""
|
||||
malicious_input = "tenant'; DROP TABLE users; --"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "tenant''; DROP TABLE users; --"
|
||||
|
||||
# Semicolons and DROP are now part of the string
|
||||
assert "DROP TABLE" in escaped
|
||||
|
||||
def test_prevent_time_based_blind_injection(self):
|
||||
"""Test prevention of time-based blind SQL injection."""
|
||||
malicious_input = "tenant' AND SLEEP(5) --"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "tenant'' AND SLEEP(5) --"
|
||||
|
||||
# SLEEP becomes part of the string
|
||||
assert "SLEEP" in escaped
|
||||
|
||||
def test_prevent_wildcard_injection(self):
|
||||
"""Test prevention of wildcard-based injection."""
|
||||
malicious_input = "tenant' OR tenant_id LIKE '%"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
assert escaped == "tenant'' OR tenant_id LIKE ''%"
|
||||
|
||||
# The LIKE and wildcard are now part of the string
|
||||
assert "LIKE" in escaped
|
||||
|
||||
def test_prevent_null_byte_injection(self):
|
||||
"""Test handling of null bytes."""
|
||||
# Null bytes can sometimes bypass filters
|
||||
malicious_input = "tenant\x00' OR '1'='1"
|
||||
escaped = escape_sql_string(malicious_input)
|
||||
# Null byte is preserved, but quote is escaped
|
||||
assert "''1''=''1" in escaped
|
||||
|
||||
# === Real-world SAAS Scenarios ===
|
||||
|
||||
def test_cross_tenant_access_attempt(self):
|
||||
"""Test prevention of cross-tenant data access."""
|
||||
# Attacker tries to access another tenant's data
|
||||
attacker_input = "tenant_b' OR tenant_id='tenant_a"
|
||||
escaped = escape_sql_string(attacker_input)
|
||||
|
||||
sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'"
|
||||
# The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a"
|
||||
# which doesn't exist - preventing access to either tenant's data
|
||||
assert "tenant_b'' OR tenant_id=''tenant_a" in sql
|
||||
|
||||
def test_cross_app_access_attempt(self):
|
||||
"""Test prevention of cross-application data access."""
|
||||
attacker_input = "app1' OR app_id='app2"
|
||||
escaped = escape_sql_string(attacker_input)
|
||||
|
||||
sql = f"WHERE app_id='{escaped}'"
|
||||
# Cannot access app2's data
|
||||
assert "app1'' OR app_id=''app2" in sql
|
||||
|
||||
def test_bypass_status_filter(self):
|
||||
"""Test prevention of bypassing status filters."""
|
||||
# Try to see all statuses instead of just 'running'
|
||||
attacker_input = "running' OR status LIKE '%"
|
||||
escaped = escape_sql_string(attacker_input)
|
||||
|
||||
sql = f"WHERE status='{escaped}'"
|
||||
# Status condition is not bypassed
|
||||
assert "running'' OR status LIKE ''%" in sql
|
||||
|
||||
# === Edge Cases ===
|
||||
|
||||
def test_escape_only_quotes(self):
|
||||
"""Test string with only quotes."""
|
||||
assert escape_sql_string("'") == "''"
|
||||
assert escape_sql_string("''") == "''''"
|
||||
|
||||
def test_escape_mixed_content(self):
|
||||
"""Test string with mixed quotes and other chars."""
|
||||
input_str = "It's a 'test' of O'Reilly's code"
|
||||
escaped = escape_sql_string(input_str)
|
||||
assert escaped == "It''s a ''test'' of O''Reilly''s code"
|
||||
|
||||
def test_escape_unicode_with_quotes(self):
|
||||
"""Test Unicode strings with quotes."""
|
||||
input_str = "租户' OR '1'='1"
|
||||
escaped = escape_sql_string(input_str)
|
||||
assert escaped == "租户'' OR ''1''=''1"
|
||||
|
||||
|
||||
class TestEscapeIdentifier:
|
||||
"""Test escape_identifier function."""
|
||||
|
||||
def test_escape_uuid(self):
|
||||
"""Test escaping UUID identifiers."""
|
||||
uuid = "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert escape_identifier(uuid) == uuid
|
||||
|
||||
def test_escape_alphanumeric_id(self):
|
||||
"""Test escaping alphanumeric identifiers."""
|
||||
assert escape_identifier("tenant_123") == "tenant_123"
|
||||
assert escape_identifier("app-abc-123") == "app-abc-123"
|
||||
|
||||
def test_escape_identifier_with_quote(self):
|
||||
"""Test escaping identifier with single quote."""
|
||||
malicious = "tenant' OR '1'='1"
|
||||
escaped = escape_identifier(malicious)
|
||||
assert escaped == "tenant'' OR ''1''=''1"
|
||||
|
||||
def test_identifier_injection_attempt(self):
|
||||
"""Test prevention of injection through identifiers."""
|
||||
# Common identifier injection patterns
|
||||
test_cases = [
|
||||
("id' OR '1'='1", "id'' OR ''1''=''1"),
|
||||
("id'; DROP TABLE", "id''; DROP TABLE"),
|
||||
("id' UNION SELECT", "id'' UNION SELECT"),
|
||||
]
|
||||
|
||||
for malicious, expected in test_cases:
|
||||
assert escape_identifier(malicious) == expected
|
||||
|
||||
|
||||
class TestSQLInjectionIntegration:
|
||||
"""Integration tests simulating real SQL construction scenarios."""
|
||||
|
||||
def test_complete_where_clause_safety(self):
|
||||
"""Test that a complete WHERE clause is safe from injection."""
|
||||
# Simulating typical query construction
|
||||
tenant_id = "tenant' OR '1'='1"
|
||||
app_id = "app' UNION SELECT"
|
||||
run_id = "run' --"
|
||||
|
||||
escaped_tenant = escape_identifier(tenant_id)
|
||||
escaped_app = escape_identifier(app_id)
|
||||
escaped_run = escape_identifier(run_id)
|
||||
|
||||
sql = f"""
|
||||
SELECT * FROM workflow_runs
|
||||
WHERE tenant_id='{escaped_tenant}'
|
||||
AND app_id='{escaped_app}'
|
||||
AND id='{escaped_run}'
|
||||
"""
|
||||
|
||||
# Verify all special characters are escaped
|
||||
assert "tenant'' OR ''1''=''1" in sql
|
||||
assert "app'' UNION SELECT" in sql
|
||||
assert "run'' --" in sql
|
||||
|
||||
# Verify SQL structure is preserved (3 conditions with AND)
|
||||
assert sql.count("AND") == 2
|
||||
|
||||
def test_multiple_conditions_with_injection_attempts(self):
|
||||
"""Test multiple conditions all attempting injection."""
|
||||
conditions = {
|
||||
"tenant_id": "t1' OR tenant_id='t2",
|
||||
"app_id": "a1' OR app_id='a2",
|
||||
"status": "running' OR '1'='1",
|
||||
}
|
||||
|
||||
where_parts = []
|
||||
for field, value in conditions.items():
|
||||
escaped = escape_sql_string(value)
|
||||
where_parts.append(f"{field}='{escaped}'")
|
||||
|
||||
where_clause = " AND ".join(where_parts)
|
||||
|
||||
# All injection attempts are neutralized
|
||||
assert "t1'' OR tenant_id=''t2" in where_clause
|
||||
assert "a1'' OR app_id=''a2" in where_clause
|
||||
assert "running'' OR ''1''=''1" in where_clause
|
||||
|
||||
# AND structure is preserved
|
||||
assert where_clause.count(" AND ") == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("attack_vector", "description"),
|
||||
[
|
||||
("' OR '1'='1", "Boolean injection"),
|
||||
("' OR '1'='1' --", "Boolean with comment"),
|
||||
("' UNION SELECT * FROM users --", "Union injection"),
|
||||
("'; DROP TABLE workflow_runs; --", "Destructive command"),
|
||||
("' AND SLEEP(10) --", "Time-based blind"),
|
||||
("' OR tenant_id LIKE '%", "Wildcard injection"),
|
||||
("admin' --", "Comment bypass"),
|
||||
("' OR 1=1 LIMIT 1 --", "Limit bypass"),
|
||||
],
|
||||
)
|
||||
def test_common_injection_vectors(self, attack_vector, description):
|
||||
"""Test protection against common injection attack vectors."""
|
||||
escaped = escape_sql_string(attack_vector)
|
||||
|
||||
# Build SQL
|
||||
sql = f"WHERE tenant_id='{escaped}'"
|
||||
|
||||
# Verify the attack string is now a safe literal
|
||||
# The key indicator: all internal single quotes are doubled
|
||||
internal_quotes = escaped.count("''")
|
||||
original_quotes = attack_vector.count("'")
|
||||
|
||||
# Each original quote should be doubled
|
||||
assert internal_quotes == original_quotes
|
||||
|
||||
# Verify SQL has exactly 2 quotes (opening and closing)
|
||||
assert sql.count("'") >= 2 # At least opening and closing
|
||||
|
||||
def test_logstore_specific_scenario(self):
|
||||
"""Test SQL injection prevention in LogStore-specific scenarios."""
|
||||
# Simulate LogStore query with window function
|
||||
tenant_id = "tenant' OR '1'='1"
|
||||
app_id = "app' UNION SELECT"
|
||||
|
||||
escaped_tenant = escape_identifier(tenant_id)
|
||||
escaped_app = escape_identifier(app_id)
|
||||
|
||||
sql = f"""
|
||||
SELECT * FROM (
|
||||
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
|
||||
FROM workflow_execution_logstore
|
||||
WHERE tenant_id='{escaped_tenant}'
|
||||
AND app_id='{escaped_app}'
|
||||
AND __time__ > 0
|
||||
) AS subquery WHERE rn = 1
|
||||
"""
|
||||
|
||||
# Complex query structure is maintained
|
||||
assert "ROW_NUMBER()" in sql
|
||||
assert "PARTITION BY id" in sql
|
||||
|
||||
# Injection attempts are escaped
|
||||
assert "tenant'' OR ''1''=''1" in sql
|
||||
assert "app'' UNION SELECT" in sql
|
||||
|
||||
|
||||
# ====================================================================================
|
||||
# Tests for LogStore Query Syntax (SDK Mode)
|
||||
# ====================================================================================
|
||||
|
||||
|
||||
class TestLogStoreQueryEscape:
|
||||
"""Test escape_logstore_query_value for SDK mode query syntax."""
|
||||
|
||||
def test_normal_value(self):
|
||||
"""Test escaping normal alphanumeric value."""
|
||||
value = "550e8400-e29b-41d4-a716-446655440000"
|
||||
escaped = escape_logstore_query_value(value)
|
||||
|
||||
# Should be wrapped in double quotes
|
||||
assert escaped == '"550e8400-e29b-41d4-a716-446655440000"'
|
||||
|
||||
def test_empty_value(self):
|
||||
"""Test escaping empty string."""
|
||||
assert escape_logstore_query_value("") == '""'
|
||||
|
||||
def test_value_with_and_keyword(self):
|
||||
"""Test that 'and' keyword is neutralized when quoted."""
|
||||
malicious = "value and field:evil"
|
||||
escaped = escape_logstore_query_value(malicious)
|
||||
|
||||
# Should be wrapped in quotes, making 'and' a literal
|
||||
assert escaped == '"value and field:evil"'
|
||||
|
||||
# Simulate using in query
|
||||
query = f"tenant_id:{escaped}"
|
||||
assert query == 'tenant_id:"value and field:evil"'
|
||||
|
||||
def test_value_with_or_keyword(self):
|
||||
"""Test that 'or' keyword is neutralized when quoted."""
|
||||
malicious = "tenant_a or tenant_id:tenant_b"
|
||||
escaped = escape_logstore_query_value(malicious)
|
||||
|
||||
assert escaped == '"tenant_a or tenant_id:tenant_b"'
|
||||
|
||||
query = f"tenant_id:{escaped}"
|
||||
assert "or" in query # Present but as literal string
|
||||
|
||||
def test_value_with_not_keyword(self):
|
||||
"""Test that 'not' keyword is neutralized when quoted."""
|
||||
malicious = "not field:value"
|
||||
escaped = escape_logstore_query_value(malicious)
|
||||
|
||||
assert escaped == '"not field:value"'
|
||||
|
||||
def test_value_with_parentheses(self):
|
||||
"""Test that parentheses are neutralized when quoted."""
|
||||
malicious = "(tenant_a or tenant_b)"
|
||||
escaped = escape_logstore_query_value(malicious)
|
||||
|
||||
assert escaped == '"(tenant_a or tenant_b)"'
|
||||
assert "(" in escaped # Present as literal
|
||||
assert ")" in escaped # Present as literal
|
||||
|
||||
def test_value_with_colon(self):
|
||||
"""Test that colons are neutralized when quoted."""
|
||||
malicious = "field:value"
|
||||
escaped = escape_logstore_query_value(malicious)
|
||||
|
||||
assert escaped == '"field:value"'
|
||||
assert ":" in escaped # Present as literal
|
||||
|
||||
def test_value_with_double_quotes(self):
|
||||
"""Test that internal double quotes are escaped."""
|
||||
value_with_quotes = 'tenant"test"value'
|
||||
escaped = escape_logstore_query_value(value_with_quotes)
|
||||
|
||||
# Double quotes should be escaped with backslash
|
||||
assert escaped == '"tenant\\"test\\"value"'
|
||||
# Should have outer quotes plus escaped inner quotes
|
||||
assert '\\"' in escaped
|
||||
|
||||
def test_value_with_backslash(self):
|
||||
"""Test that backslashes are escaped."""
|
||||
value_with_backslash = "tenant\\test"
|
||||
escaped = escape_logstore_query_value(value_with_backslash)
|
||||
|
||||
# Backslash should be escaped
|
||||
assert escaped == '"tenant\\\\test"'
|
||||
assert "\\\\" in escaped
|
||||
|
||||
def test_value_with_backslash_and_quote(self):
|
||||
"""Test escaping both backslash and double quote."""
|
||||
value = 'path\\to\\"file"'
|
||||
escaped = escape_logstore_query_value(value)
|
||||
|
||||
# Both should be escaped
|
||||
assert escaped == '"path\\\\to\\\\\\"file\\""'
|
||||
# Verify escape order is correct
|
||||
assert "\\\\" in escaped # Escaped backslash
|
||||
assert '\\"' in escaped # Escaped double quote
|
||||
|
||||
def test_complex_injection_attempt(self):
|
||||
"""Test complex injection combining multiple operators."""
|
||||
malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")'
|
||||
escaped = escape_logstore_query_value(malicious)
|
||||
|
||||
# All special chars should be literals or escaped
|
||||
assert escaped.startswith('"')
|
||||
assert escaped.endswith('"')
|
||||
# Inner double quotes escaped, operators become literals
|
||||
assert "or" in escaped
|
||||
assert "and" in escaped
|
||||
assert '\\"' in escaped # Escaped quotes
|
||||
|
||||
def test_only_backslash(self):
|
||||
"""Test escaping a single backslash."""
|
||||
assert escape_logstore_query_value("\\") == '"\\\\"'
|
||||
|
||||
def test_only_double_quote(self):
|
||||
"""Test escaping a single double quote."""
|
||||
assert escape_logstore_query_value('"') == '"\\""'
|
||||
|
||||
def test_multiple_backslashes(self):
|
||||
"""Test escaping multiple consecutive backslashes."""
|
||||
assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6
|
||||
|
||||
def test_escape_sequence_like_input(self):
|
||||
"""Test that existing escape sequences are properly escaped."""
|
||||
# Input looks like already escaped, but we still escape it
|
||||
value = 'value\\"test'
|
||||
escaped = escape_logstore_query_value(value)
|
||||
# \\ -> \\\\, " -> \"
|
||||
assert escaped == '"value\\\\\\"test"'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("attack_scenario", "field", "malicious_value"),
|
||||
[
|
||||
("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"),
|
||||
("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"),
|
||||
("Boolean logic", "status", "succeeded or status:failed"),
|
||||
("Negation", "tenant_id", "not tenant_a"),
|
||||
("Field injection", "run_id", "run123 and tenant_id:evil_tenant"),
|
||||
("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"),
|
||||
("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'),
|
||||
("Backslash escape bypass", "app_id", "app\\ and app_id:evil"),
|
||||
],
|
||||
)
|
||||
def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str):
|
||||
"""Test that various LogStore query injection attempts are neutralized."""
|
||||
escaped = escape_logstore_query_value(malicious_value)
|
||||
|
||||
# Build query
|
||||
query = f"{field}:{escaped}"
|
||||
|
||||
# All operators should be within quoted string (literals)
|
||||
assert escaped.startswith('"')
|
||||
assert escaped.endswith('"')
|
||||
|
||||
# Verify the full query structure is safe
|
||||
assert query.count(":") >= 1 # At least the main field:value separator
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""Unit tests for traceparent header propagation in EnterpriseRequest.
|
||||
|
||||
This test module verifies that the W3C traceparent header is properly
|
||||
generated and included in HTTP requests made by EnterpriseRequest.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
|
||||
class TestTraceparentPropagation:
|
||||
"""Unit tests for traceparent header propagation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_enterprise_config(self):
|
||||
"""Mock EnterpriseRequest configuration."""
|
||||
with (
|
||||
patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"),
|
||||
patch.object(EnterpriseRequest, "secret_key", "test-secret-key"),
|
||||
patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"),
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client(self):
|
||||
"""Mock httpx.Client for testing."""
|
||||
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client
|
||||
mock_client_class.return_value.__exit__.return_value = None
|
||||
|
||||
# Setup default response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_client.request.return_value = mock_response
|
||||
|
||||
yield mock_client
|
||||
|
||||
def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client):
|
||||
"""Test that traceparent header is included when successfully generated."""
|
||||
# Arrange
|
||||
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
|
||||
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
|
||||
# Act
|
||||
EnterpriseRequest.send_request("GET", "/test")
|
||||
|
||||
# Assert
|
||||
mock_httpx_client.request.assert_called_once()
|
||||
call_args = mock_httpx_client.request.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
|
||||
assert "traceparent" in headers
|
||||
assert headers["traceparent"] == expected_traceparent
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key"
|
||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from models.account import Account
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
|
|
@ -1147,9 +1147,13 @@ class TestRegisterService:
|
|||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||
|
||||
with patch("services.account_service.Session") as mock_session_class:
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
# Mock RegisterService.register
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
|
|
@ -1182,9 +1186,59 @@ class TestRegisterService:
|
|||
email="newuser@example.com",
|
||||
name="newuser",
|
||||
language="en-US",
|
||||
status="pending",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
|
||||
|
||||
def test_invite_new_member_normalizes_new_account_email(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
):
|
||||
"""Ensure inviting with mixed-case email normalizes before registering."""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-456"
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
mixed_email = "Invitee@Example.com"
|
||||
|
||||
mock_session = MagicMock()
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending"
|
||||
)
|
||||
with patch("services.account_service.RegisterService.register") as mock_register:
|
||||
mock_register.return_value = mock_new_account
|
||||
with (
|
||||
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
|
||||
patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
|
||||
patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
|
||||
patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
|
||||
):
|
||||
mock_generate_token.return_value = "invite-token-abc"
|
||||
|
||||
RegisterService.invite_new_member(
|
||||
tenant=mock_tenant,
|
||||
email=mixed_email,
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
)
|
||||
|
||||
mock_register.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
name="invitee",
|
||||
language="en-US",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
|
||||
|
|
@ -1207,9 +1261,13 @@ class TestRegisterService:
|
|||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||
|
||||
with patch("services.account_service.Session") as mock_session_class:
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock the db.session.query for TenantAccountJoin
|
||||
mock_db_query = MagicMock()
|
||||
|
|
@ -1238,6 +1296,7 @@ class TestRegisterService:
|
|||
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
||||
mock_task_dependencies.delay.assert_called_once()
|
||||
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
|
||||
|
||||
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
||||
"""Test inviting a member who is already in the tenant."""
|
||||
|
|
@ -1251,7 +1310,6 @@ class TestRegisterService:
|
|||
|
||||
# Mock database queries
|
||||
query_results = {
|
||||
("Account", "email", "existing@example.com"): mock_existing_account,
|
||||
(
|
||||
"TenantAccountJoin",
|
||||
"tenant_id",
|
||||
|
|
@ -1261,7 +1319,11 @@ class TestRegisterService:
|
|||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
|
||||
# Mock TenantService methods
|
||||
with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission:
|
||||
with (
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
|
||||
):
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountAlreadyInTenantError,
|
||||
|
|
@ -1272,6 +1334,7 @@ class TestRegisterService:
|
|||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
)
|
||||
mock_lookup.assert_called_once()
|
||||
|
||||
def test_invite_new_member_no_inviter(self):
|
||||
"""Test inviting a member without providing an inviter."""
|
||||
|
|
@ -1497,6 +1560,30 @@ class TestRegisterService:
|
|||
# Verify results
|
||||
assert result is None
|
||||
|
||||
def test_get_invitation_with_case_fallback_returns_initial_match(self):
|
||||
"""Fallback helper should return the initial invitation when present."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch(
|
||||
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
|
||||
) as mock_get:
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
assert result == invitation
|
||||
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
|
||||
"""Fallback helper should retry with lowercase email when needed."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
|
||||
mock_get.side_effect = [None, invitation]
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
assert result == invitation
|
||||
assert mock_get.call_args_list == [
|
||||
(("tenant-456", "User@Test.com", "token-123"),),
|
||||
(("tenant-456", "user@test.com", "token-123"),),
|
||||
]
|
||||
|
||||
# ==================== Helper Method Tests ====================
|
||||
|
||||
def test_get_invitation_token_key(self):
|
||||
|
|
|
|||
14
api/uv.lock
14
api/uv.lock
|
|
@ -453,15 +453,15 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "azure-core"
|
||||
version = "1.36.0"
|
||||
version = "1.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1368,7 +1368,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.11.2"
|
||||
version = "1.11.3"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
|
|
@ -1965,11 +1965,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.20.0"
|
||||
version = "3.20.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
|
|||
# Options:
|
||||
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
|
||||
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
|
||||
# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
|
||||
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
|
||||
|
||||
# Core workflow node execution repository implementation
|
||||
# Options:
|
||||
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
|
||||
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
|
||||
# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
|
||||
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
# API workflow run repository implementation
|
||||
# Options:
|
||||
# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default)
|
||||
# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
|
||||
# API workflow node execution repository implementation
|
||||
# Options:
|
||||
# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default)
|
||||
# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository
|
||||
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
# Workflow log cleanup configuration
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ services:
|
|||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -63,7 +63,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -102,7 +102,7 @@ services:
|
|||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -132,7 +132,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.11.2
|
||||
image: langgenius/dify-web:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
|
|||
|
|
@ -704,7 +704,7 @@ services:
|
|||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -746,7 +746,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -785,7 +785,7 @@ services:
|
|||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -815,7 +815,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.11.2
|
||||
image: langgenius/dify-web:1.11.3
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
|
|||
|
||||
# The timeout for the text generation in millisecond
|
||||
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
|
||||
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only)
|
||||
TEXT_GENERATION_TIMEOUT_MS=60000
|
||||
|
||||
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
|
||||
NEXT_PUBLIC_CSP_WHITELIST=
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => {
|
|||
)
|
||||
return {
|
||||
useGlobalPublicStore,
|
||||
useIsSystemFeaturesPending: () => false,
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ import {
|
|||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
|
||||
} from '@/app/education-apply/constants'
|
||||
import { fetchSetupStatus } from '@/service/common'
|
||||
import { sendGAEvent } from '@/utils/gtag'
|
||||
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
|
||||
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
|
||||
import { trackEvent } from './base/amplitude'
|
||||
|
||||
|
|
@ -33,15 +33,8 @@ export const AppInitializer = ({
|
|||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
try {
|
||||
if (localStorage.getItem('setup_status') === 'finished')
|
||||
return true
|
||||
const setUpStatus = await fetchSetupStatus()
|
||||
if (setUpStatus.step !== 'finished') {
|
||||
localStorage.removeItem('setup_status')
|
||||
return false
|
||||
}
|
||||
localStorage.setItem('setup_status', 'finished')
|
||||
return true
|
||||
const setUpStatus = await fetchSetupStatusWithCache()
|
||||
return setUpStatus.step === 'finished'
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
|
|
|
|||
|
|
@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({
|
|||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
fetchCurrentWorkspace: vi.fn(),
|
||||
fetchLangGeniusVersion: vi.fn(),
|
||||
fetchUserProfile: vi.fn(),
|
||||
getSystemFeatures: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/access-control', () => ({
|
||||
useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args),
|
||||
useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args),
|
||||
|
|
@ -125,7 +118,6 @@ const resetAccessControlStore = () => {
|
|||
const resetGlobalStore = () => {
|
||||
useGlobalPublicStore.setState({
|
||||
systemFeatures: defaultSystemFeatures,
|
||||
isGlobalPending: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => {
|
|||
}
|
||||
|
||||
const AmplitudeProvider: FC<IAmplitudeProps> = ({
|
||||
sessionReplaySampleRate = 1,
|
||||
sessionReplaySampleRate = 0.5,
|
||||
}) => {
|
||||
useEffect(() => {
|
||||
// Only enable in Saas edition with valid API key
|
||||
|
|
|
|||
|
|
@ -170,8 +170,12 @@ describe('useChatWithHistory', () => {
|
|||
await waitFor(() => {
|
||||
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1')
|
||||
})
|
||||
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
|
||||
expect(result.current.conversationList).toEqual(listData.data)
|
||||
await waitFor(() => {
|
||||
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
|
||||
})
|
||||
await waitFor(() => {
|
||||
expect(result.current.conversationList).toEqual(listData.data)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
|||
import * as React from 'react'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import Toast from '../../../../base/toast'
|
||||
import { ALL_PLANS } from '../../../config'
|
||||
import { Plan } from '../../../type'
|
||||
|
|
@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({
|
|||
}))
|
||||
|
||||
vi.mock('@/service/billing', () => ({
|
||||
fetchBillingUrl: vi.fn(),
|
||||
fetchSubscriptionUrls: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleClient: {
|
||||
billingUrl: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: vi.fn(),
|
||||
}))
|
||||
|
|
@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({
|
|||
|
||||
const mockUseAppContext = useAppContext as Mock
|
||||
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
|
||||
const mockFetchBillingUrl = fetchBillingUrl as Mock
|
||||
const mockBillingUrl = consoleClient.billingUrl as Mock
|
||||
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
|
||||
const mockToastNotify = Toast.notify as Mock
|
||||
|
||||
|
|
@ -69,7 +75,7 @@ beforeEach(() => {
|
|||
vi.clearAllMocks()
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
|
||||
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
|
||||
mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
|
||||
mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
|
||||
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
|
||||
assignedHref = ''
|
||||
})
|
||||
|
|
@ -143,7 +149,7 @@ describe('CloudPlanItem', () => {
|
|||
type: 'error',
|
||||
message: 'billing.buyPermissionDeniedTip',
|
||||
}))
|
||||
expect(mockFetchBillingUrl).not.toHaveBeenCalled()
|
||||
expect(mockBillingUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should open billing portal when upgrading current paid plan', async () => {
|
||||
|
|
@ -162,7 +168,7 @@ describe('CloudPlanItem', () => {
|
|||
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1)
|
||||
expect(mockBillingUrl).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
expect(openWindow).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ import { useMemo } from 'react'
|
|||
import { useTranslation } from 'react-i18next'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import Toast from '../../../../base/toast'
|
||||
import { ALL_PLANS } from '../../../config'
|
||||
import { Plan } from '../../../type'
|
||||
|
|
@ -76,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
|||
try {
|
||||
if (isCurrentPaidPlan) {
|
||||
await openAsyncWindow(async () => {
|
||||
const res = await fetchBillingUrl()
|
||||
const res = await consoleClient.billingUrl()
|
||||
if (res.url)
|
||||
return res.url
|
||||
throw new Error('Failed to open billing page')
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
|
|||
category: PluginCategoryEnum.datasource,
|
||||
exclude,
|
||||
type: 'plugin',
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
})
|
||||
}
|
||||
else {
|
||||
|
|
@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
|
|||
query: '',
|
||||
category: PluginCategoryEnum.datasource,
|
||||
type: 'plugin',
|
||||
pageSize: 1000,
|
||||
page_size: 1000,
|
||||
exclude,
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
})
|
||||
}
|
||||
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])
|
||||
|
|
|
|||
|
|
@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
|
|||
category: PluginCategoryEnum.model,
|
||||
exclude,
|
||||
type: 'plugin',
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
})
|
||||
}
|
||||
else {
|
||||
|
|
@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
|
|||
query: '',
|
||||
category: PluginCategoryEnum.model,
|
||||
type: 'plugin',
|
||||
pageSize: 1000,
|
||||
page_size: 1000,
|
||||
exclude,
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
})
|
||||
}
|
||||
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])
|
||||
|
|
|
|||
|
|
@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => {
|
|||
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()
|
||||
|
||||
const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => {
|
||||
const pageSize = pluginsSearchParams.pageSize || 40
|
||||
const page_size = pluginsSearchParams.page_size || 40
|
||||
|
||||
return {
|
||||
...pluginsSearchParams,
|
||||
pageSize,
|
||||
page_size,
|
||||
}
|
||||
}, [])
|
||||
|
||||
|
|
@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => {
|
|||
plugins: [] as Plugin[],
|
||||
total: 0,
|
||||
page: 1,
|
||||
pageSize: 40,
|
||||
page_size: 40,
|
||||
}
|
||||
}
|
||||
|
||||
const params = normalizeParams(queryParams)
|
||||
const {
|
||||
query,
|
||||
sortBy,
|
||||
sortOrder,
|
||||
sort_by,
|
||||
sort_order,
|
||||
category,
|
||||
tags,
|
||||
exclude,
|
||||
type,
|
||||
pageSize,
|
||||
page_size,
|
||||
} = params
|
||||
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
|
||||
|
||||
|
|
@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => {
|
|||
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
|
||||
body: {
|
||||
page: pageParam,
|
||||
page_size: pageSize,
|
||||
page_size,
|
||||
query,
|
||||
sort_by: sortBy,
|
||||
sort_order: sortOrder,
|
||||
sort_by,
|
||||
sort_order,
|
||||
category: category !== 'all' ? category : '',
|
||||
tags,
|
||||
exclude,
|
||||
|
|
@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => {
|
|||
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
|
||||
total: res.data.total,
|
||||
page: pageParam,
|
||||
pageSize,
|
||||
page_size,
|
||||
}
|
||||
}
|
||||
catch {
|
||||
|
|
@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => {
|
|||
plugins: [],
|
||||
total: 0,
|
||||
page: pageParam,
|
||||
pageSize,
|
||||
page_size,
|
||||
}
|
||||
}
|
||||
},
|
||||
getNextPageParam: (lastPage) => {
|
||||
const nextPage = lastPage.page + 1
|
||||
const loaded = lastPage.page * lastPage.pageSize
|
||||
const loaded = lastPage.page * lastPage.page_size
|
||||
return loaded < (lastPage.total || 0) ? nextPage : undefined
|
||||
},
|
||||
initialPageParam: 1,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs'
|
|||
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
|
||||
import { createLoader } from 'nuqs/server'
|
||||
import { getQueryClientServer } from '@/context/query-client-server'
|
||||
import { marketplaceQuery } from '@/service/client'
|
||||
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
|
||||
import { marketplaceKeys } from './query'
|
||||
import { marketplaceSearchParamsParsers } from './search-params'
|
||||
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise<SearchParams>) {
|
|||
const queryClient = getQueryClientServer()
|
||||
|
||||
await queryClient.prefetchQuery({
|
||||
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
|
||||
queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }),
|
||||
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
|
||||
})
|
||||
return dehydrate(queryClient)
|
||||
|
|
|
|||
|
|
@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({
|
|||
// Mock tanstack query
|
||||
const mockFetchNextPage = vi.fn()
|
||||
const mockHasNextPage = false
|
||||
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined
|
||||
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined
|
||||
let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>) | null = null
|
||||
let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise<unknown>) | null = null
|
||||
let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null
|
||||
let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null
|
||||
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise<unknown>, enabled: boolean }) => {
|
||||
|
|
@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({
|
|||
}),
|
||||
useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: {
|
||||
queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>
|
||||
getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined
|
||||
getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined
|
||||
enabled: boolean
|
||||
}) => {
|
||||
// Capture queryFn and getNextPageParam for later testing
|
||||
|
|
@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({
|
|||
// Call getNextPageParam to increase coverage
|
||||
if (getNextPageParam) {
|
||||
// Test with more data available
|
||||
getNextPageParam({ page: 1, pageSize: 40, total: 100 })
|
||||
getNextPageParam({ page: 1, page_size: 40, total: 100 })
|
||||
// Test with no more data
|
||||
getNextPageParam({ page: 3, pageSize: 40, total: 100 })
|
||||
getNextPageParam({ page: 3, page_size: 40, total: 100 })
|
||||
}
|
||||
return {
|
||||
data: mockInfiniteQueryData,
|
||||
|
|
@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({
|
|||
|
||||
// Mock config
|
||||
vi.mock('@/config', () => ({
|
||||
API_PREFIX: '/api',
|
||||
APP_VERSION: '1.0.0',
|
||||
IS_MARKETPLACE: false,
|
||||
MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1',
|
||||
|
|
@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => {
|
|||
expect(() => {
|
||||
result.current.queryPlugins({
|
||||
query: 'test',
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
category: 'tool',
|
||||
pageSize: 20,
|
||||
page_size: 20,
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
|
@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => {
|
|||
result.current.queryPlugins({
|
||||
query: 'test',
|
||||
type: 'bundle',
|
||||
pageSize: 40,
|
||||
page_size: 40,
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
|
@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => {
|
|||
result.current.queryPlugins({
|
||||
query: 'test',
|
||||
category: 'all',
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
|
@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => {
|
|||
expect(() => {
|
||||
result.current.queryPlugins({
|
||||
query: 'test',
|
||||
pageSize: 100,
|
||||
page_size: 100,
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
|
@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => {
|
|||
// Set mock data to have pages
|
||||
mockInfiniteQueryData = {
|
||||
pages: [
|
||||
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 },
|
||||
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 },
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => {
|
|||
it('should expose page and total from infinite query data', async () => {
|
||||
mockInfiniteQueryData = {
|
||||
pages: [
|
||||
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 },
|
||||
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 },
|
||||
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 },
|
||||
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 },
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => {
|
|||
it('should return total from first page when query is set and data exists', async () => {
|
||||
mockInfiniteQueryData = {
|
||||
pages: [
|
||||
{ plugins: [], total: 50, page: 1, pageSize: 40 },
|
||||
{ plugins: [], total: 50, page: 1, page_size: 40 },
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => {
|
|||
type: 'plugin',
|
||||
query: 'search test',
|
||||
category: 'model',
|
||||
sortBy: 'version_updated_at',
|
||||
sortOrder: 'ASC',
|
||||
sort_by: 'version_updated_at',
|
||||
sort_order: 'ASC',
|
||||
})
|
||||
|
||||
expect(result.current).toBeDefined()
|
||||
|
|
@ -1027,13 +1028,13 @@ describe('Advanced Hook Integration', () => {
|
|||
// Test with all possible parameters
|
||||
result.current.queryPlugins({
|
||||
query: 'comprehensive test',
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
category: 'tool',
|
||||
tags: ['tag1', 'tag2'],
|
||||
exclude: ['excluded-plugin'],
|
||||
type: 'plugin',
|
||||
pageSize: 50,
|
||||
page_size: 50,
|
||||
})
|
||||
|
||||
expect(result.current).toBeDefined()
|
||||
|
|
@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => {
|
|||
result.current.queryPlugins({
|
||||
query: 'direct test',
|
||||
category: 'tool',
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
pageSize: 40,
|
||||
sort_by: 'install_count',
|
||||
sort_order: 'DESC',
|
||||
page_size: 40,
|
||||
})
|
||||
|
||||
// Now queryFn should be captured and enabled
|
||||
|
|
@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => {
|
|||
|
||||
result.current.queryPlugins({
|
||||
query: 'structure test',
|
||||
pageSize: 20,
|
||||
page_size: 20,
|
||||
})
|
||||
|
||||
if (capturedInfiniteQueryFn) {
|
||||
|
|
@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => {
|
|||
plugins: unknown[]
|
||||
total: number
|
||||
page: number
|
||||
pageSize: number
|
||||
page_size: number
|
||||
}
|
||||
|
||||
// Verify the returned structure
|
||||
expect(response).toHaveProperty('plugins')
|
||||
expect(response).toHaveProperty('total')
|
||||
expect(response).toHaveProperty('page')
|
||||
expect(response).toHaveProperty('pageSize')
|
||||
expect(response).toHaveProperty('page_size')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
|
@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => {
|
|||
],
|
||||
total: 5,
|
||||
page: 1,
|
||||
pageSize: 40,
|
||||
page_size: 40,
|
||||
},
|
||||
{
|
||||
plugins: [
|
||||
|
|
@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => {
|
|||
],
|
||||
total: 5,
|
||||
page: 2,
|
||||
pageSize: 40,
|
||||
page_size: 40,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
|
@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => {
|
|||
it('should test hook with pages data for flatMap path', async () => {
|
||||
mockInfiniteQueryData = {
|
||||
pages: [
|
||||
{ plugins: [], total: 100, page: 1, pageSize: 40 },
|
||||
{ plugins: [], total: 100, page: 2, pageSize: 40 },
|
||||
{ plugins: [], total: 100, page: 1, page_size: 40 },
|
||||
{ plugins: [], total: 100, page: 2, page_size: 40 },
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => {
|
|||
plugins: unknown[]
|
||||
total: number
|
||||
page: number
|
||||
pageSize: number
|
||||
page_size: number
|
||||
}
|
||||
// When error is caught, should return fallback data
|
||||
expect(response.plugins).toEqual([])
|
||||
|
|
@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => {
|
|||
// Test getNextPageParam function directly
|
||||
if (capturedGetNextPageParam) {
|
||||
// When there are more pages
|
||||
const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 })
|
||||
const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 })
|
||||
expect(nextPage).toBe(2)
|
||||
|
||||
// When all data is loaded
|
||||
const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 })
|
||||
const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 })
|
||||
expect(noMorePages).toBeUndefined()
|
||||
|
||||
// Edge case: exactly at boundary
|
||||
const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 })
|
||||
const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 })
|
||||
expect(atBoundary).toBeUndefined()
|
||||
}
|
||||
})
|
||||
|
|
@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => {
|
|||
plugins: unknown[]
|
||||
total: number
|
||||
page: number
|
||||
pageSize: number
|
||||
page_size: number
|
||||
}
|
||||
// Catch block should return fallback values
|
||||
expect(response.plugins).toEqual([])
|
||||
|
|
@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => {
|
|||
plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }],
|
||||
total: 10,
|
||||
page: 1,
|
||||
pageSize: 40,
|
||||
page_size: 40,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
|
@ -1489,9 +1490,12 @@ describe('Async Utils', () => {
|
|||
{ type: 'plugin', org: 'test', name: 'plugin2' },
|
||||
]
|
||||
|
||||
globalThis.fetch = vi.fn().mockResolvedValue({
|
||||
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
|
||||
})
|
||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
|
||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||
const result = await getMarketplacePluginsByCollectionId('test-collection', {
|
||||
|
|
@ -1514,19 +1518,26 @@ describe('Async Utils', () => {
|
|||
})
|
||||
|
||||
it('should pass abort signal when provided', async () => {
|
||||
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
|
||||
globalThis.fetch = vi.fn().mockResolvedValue({
|
||||
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
|
||||
})
|
||||
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
|
||||
const controller = new AbortController()
|
||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
|
||||
|
||||
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
|
||||
expect(globalThis.fetch).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({ signal: controller.signal }),
|
||||
expect.any(Request),
|
||||
expect.any(Object),
|
||||
)
|
||||
const call = vi.mocked(globalThis.fetch).mock.calls[0]
|
||||
const request = call[0] as Request
|
||||
expect(request.url).toContain('test-collection')
|
||||
})
|
||||
})
|
||||
|
||||
|
|
@ -1535,19 +1546,25 @@ describe('Async Utils', () => {
|
|||
const mockCollections = [
|
||||
{ name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' },
|
||||
]
|
||||
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
|
||||
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
||||
|
||||
let callCount = 0
|
||||
globalThis.fetch = vi.fn().mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
json: () => Promise.resolve({ data: { collections: mockCollections } }),
|
||||
})
|
||||
return Promise.resolve(
|
||||
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
}
|
||||
return Promise.resolve({
|
||||
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
|
||||
})
|
||||
return Promise.resolve(
|
||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||
|
|
@ -1571,9 +1588,12 @@ describe('Async Utils', () => {
|
|||
})
|
||||
|
||||
it('should append condition and type to URL when provided', async () => {
|
||||
globalThis.fetch = vi.fn().mockResolvedValue({
|
||||
json: () => Promise.resolve({ data: { collections: [] } }),
|
||||
})
|
||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
||||
new Response(JSON.stringify({ data: { collections: [] } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
|
||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||
await getMarketplaceCollectionsAndPlugins({
|
||||
|
|
@ -1581,10 +1601,11 @@ describe('Async Utils', () => {
|
|||
type: 'bundle',
|
||||
})
|
||||
|
||||
expect(globalThis.fetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining('condition=category=tool'),
|
||||
expect.any(Object),
|
||||
)
|
||||
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
|
||||
expect(globalThis.fetch).toHaveBeenCalled()
|
||||
const call = vi.mocked(globalThis.fetch).mock.calls[0]
|
||||
const request = call[0] as Request
|
||||
expect(request.url).toContain('condition=category%3Dtool')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,22 +1,14 @@
|
|||
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
|
||||
import type { PluginsSearchParams } from './types'
|
||||
import type { MarketPlaceInputs } from '@/contract/router'
|
||||
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
|
||||
import { marketplaceQuery } from '@/service/client'
|
||||
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
|
||||
|
||||
// TODO: Avoid manual maintenance of query keys and better service management,
|
||||
// https://github.com/langgenius/dify/issues/30342
|
||||
|
||||
export const marketplaceKeys = {
|
||||
all: ['marketplace'] as const,
|
||||
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
|
||||
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
|
||||
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
|
||||
}
|
||||
|
||||
export function useMarketplaceCollectionsAndPlugins(
|
||||
collectionsParams: CollectionsAndPluginsSearchParams,
|
||||
collectionsParams: MarketPlaceInputs['collections']['query'],
|
||||
) {
|
||||
return useQuery({
|
||||
queryKey: marketplaceKeys.collections(collectionsParams),
|
||||
queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }),
|
||||
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
|
||||
})
|
||||
}
|
||||
|
|
@ -25,11 +17,16 @@ export function useMarketplacePlugins(
|
|||
queryParams: PluginsSearchParams | undefined,
|
||||
) {
|
||||
return useInfiniteQuery({
|
||||
queryKey: marketplaceKeys.plugins(queryParams),
|
||||
queryKey: marketplaceQuery.searchAdvanced.queryKey({
|
||||
input: {
|
||||
body: queryParams!,
|
||||
params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' },
|
||||
},
|
||||
}),
|
||||
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
|
||||
getNextPageParam: (lastPage) => {
|
||||
const nextPage = lastPage.page + 1
|
||||
const loaded = lastPage.page * lastPage.pageSize
|
||||
const loaded = lastPage.page * lastPage.page_size
|
||||
return loaded < (lastPage.total || 0) ? nextPage : undefined
|
||||
},
|
||||
initialPageParam: 1,
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ export function useMarketplaceData() {
|
|||
query: searchPluginText,
|
||||
category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType,
|
||||
tags: filterPluginTags,
|
||||
sortBy: sort.sortBy,
|
||||
sortOrder: sort.sortOrder,
|
||||
sort_by: sort.sortBy,
|
||||
sort_order: sort.sortOrder,
|
||||
type: getMarketplaceListFilterType(activePluginType),
|
||||
}
|
||||
}, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort])
|
||||
|
|
|
|||
|
|
@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = {
|
|||
export type PluginsSearchParams = {
|
||||
query: string
|
||||
page?: number
|
||||
pageSize?: number
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
page_size?: number
|
||||
sort_by?: string
|
||||
sort_order?: string
|
||||
category?: string
|
||||
tags?: string[]
|
||||
exclude?: string[]
|
||||
|
|
|
|||
|
|
@ -4,14 +4,12 @@ import type {
|
|||
MarketplaceCollection,
|
||||
PluginsSearchParams,
|
||||
} from '@/app/components/plugins/marketplace/types'
|
||||
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
|
||||
import type { Plugin } from '@/app/components/plugins/types'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import {
|
||||
APP_VERSION,
|
||||
IS_MARKETPLACE,
|
||||
MARKETPLACE_API_PREFIX,
|
||||
} from '@/config'
|
||||
import { postMarketplace } from '@/service/base'
|
||||
import { marketplaceClient } from '@/service/client'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
|
||||
|
||||
|
|
@ -19,10 +17,6 @@ type MarketplaceFetchOptions = {
|
|||
signal?: AbortSignal
|
||||
}
|
||||
|
||||
const getMarketplaceHeaders = () => new Headers({
|
||||
'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
|
||||
})
|
||||
|
||||
export const getPluginIconInMarketplace = (plugin: Plugin) => {
|
||||
if (plugin.type === 'bundle')
|
||||
return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon`
|
||||
|
|
@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async (
|
|||
let plugins: Plugin[] = []
|
||||
|
||||
try {
|
||||
const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins`
|
||||
const headers = getMarketplaceHeaders()
|
||||
const marketplaceCollectionPluginsData = await globalThis.fetch(
|
||||
url,
|
||||
{
|
||||
cache: 'no-store',
|
||||
method: 'POST',
|
||||
headers,
|
||||
signal: options?.signal,
|
||||
body: JSON.stringify({
|
||||
category: query?.category,
|
||||
exclude: query?.exclude,
|
||||
type: query?.type,
|
||||
}),
|
||||
const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({
|
||||
params: {
|
||||
collectionId,
|
||||
},
|
||||
)
|
||||
const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json()
|
||||
plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin))
|
||||
body: query,
|
||||
}, {
|
||||
signal: options?.signal,
|
||||
})
|
||||
plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin))
|
||||
}
|
||||
// eslint-disable-next-line unused-imports/no-unused-vars
|
||||
catch (e) {
|
||||
|
|
@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async (
|
|||
let marketplaceCollections: MarketplaceCollection[] = []
|
||||
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
|
||||
try {
|
||||
let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100`
|
||||
if (query?.condition)
|
||||
marketplaceUrl += `&condition=${query.condition}`
|
||||
if (query?.type)
|
||||
marketplaceUrl += `&type=${query.type}`
|
||||
const headers = getMarketplaceHeaders()
|
||||
const marketplaceCollectionsData = await globalThis.fetch(
|
||||
marketplaceUrl,
|
||||
{
|
||||
headers,
|
||||
cache: 'no-store',
|
||||
signal: options?.signal,
|
||||
const marketplaceCollectionsDataJson = await marketplaceClient.collections({
|
||||
query: {
|
||||
...query,
|
||||
page: 1,
|
||||
page_size: 100,
|
||||
},
|
||||
)
|
||||
const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json()
|
||||
marketplaceCollections = marketplaceCollectionsDataJson.data.collections || []
|
||||
}, {
|
||||
signal: options?.signal,
|
||||
})
|
||||
marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || []
|
||||
await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {
|
||||
const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options)
|
||||
|
||||
|
|
@ -143,42 +122,42 @@ export const getMarketplacePlugins = async (
|
|||
plugins: [] as Plugin[],
|
||||
total: 0,
|
||||
page: 1,
|
||||
pageSize: 40,
|
||||
page_size: 40,
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
query,
|
||||
sortBy,
|
||||
sortOrder,
|
||||
sort_by,
|
||||
sort_order,
|
||||
category,
|
||||
tags,
|
||||
type,
|
||||
pageSize = 40,
|
||||
page_size = 40,
|
||||
} = queryParams
|
||||
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
|
||||
|
||||
try {
|
||||
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
|
||||
const res = await marketplaceClient.searchAdvanced({
|
||||
params: {
|
||||
kind: type === 'bundle' ? 'bundles' : 'plugins',
|
||||
},
|
||||
body: {
|
||||
page: pageParam,
|
||||
page_size: pageSize,
|
||||
page_size,
|
||||
query,
|
||||
sort_by: sortBy,
|
||||
sort_order: sortOrder,
|
||||
sort_by,
|
||||
sort_order,
|
||||
category: category !== 'all' ? category : '',
|
||||
tags,
|
||||
type,
|
||||
},
|
||||
signal,
|
||||
})
|
||||
}, { signal })
|
||||
const resPlugins = res.data.bundles || res.data.plugins || []
|
||||
|
||||
return {
|
||||
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
|
||||
total: res.data.total,
|
||||
page: pageParam,
|
||||
pageSize,
|
||||
page_size,
|
||||
}
|
||||
}
|
||||
catch {
|
||||
|
|
@ -186,7 +165,7 @@ export const getMarketplacePlugins = async (
|
|||
plugins: [],
|
||||
total: 0,
|
||||
page: pageParam,
|
||||
pageSize,
|
||||
page_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1606,6 +1606,7 @@ export const useNodesInteractions = () => {
|
|||
const offsetX = currentPosition.x - x
|
||||
const offsetY = currentPosition.y - y
|
||||
let idMapping: Record<string, string> = {}
|
||||
const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = []
|
||||
clipboardElements.forEach((nodeToPaste, index) => {
|
||||
const nodeType = nodeToPaste.data.type
|
||||
|
||||
|
|
@ -1619,6 +1620,7 @@ export const useNodesInteractions = () => {
|
|||
_isBundled: false,
|
||||
_connectedSourceHandleIds: [],
|
||||
_connectedTargetHandleIds: [],
|
||||
_dimmed: false,
|
||||
title: genNewNodeTitleFromOld(nodeToPaste.data.title),
|
||||
},
|
||||
position: {
|
||||
|
|
@ -1686,27 +1688,24 @@ export const useNodesInteractions = () => {
|
|||
return
|
||||
|
||||
// handle paste to nested block
|
||||
if (selectedNode.data.type === BlockEnum.Iteration) {
|
||||
newNode.data.isInIteration = true
|
||||
newNode.data.iteration_id = selectedNode.data.iteration_id
|
||||
newNode.parentId = selectedNode.id
|
||||
newNode.positionAbsolute = {
|
||||
x: newNode.position.x,
|
||||
y: newNode.position.y,
|
||||
}
|
||||
// set position base on parent node
|
||||
newNode.position = getNestedNodePosition(newNode, selectedNode)
|
||||
}
|
||||
else if (selectedNode.data.type === BlockEnum.Loop) {
|
||||
newNode.data.isInLoop = true
|
||||
newNode.data.loop_id = selectedNode.data.loop_id
|
||||
if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) {
|
||||
const isIteration = selectedNode.data.type === BlockEnum.Iteration
|
||||
|
||||
newNode.data.isInIteration = isIteration
|
||||
newNode.data.iteration_id = isIteration ? selectedNode.id : undefined
|
||||
newNode.data.isInLoop = !isIteration
|
||||
newNode.data.loop_id = !isIteration ? selectedNode.id : undefined
|
||||
|
||||
newNode.parentId = selectedNode.id
|
||||
newNode.zIndex = isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
|
||||
newNode.positionAbsolute = {
|
||||
x: newNode.position.x,
|
||||
y: newNode.position.y,
|
||||
}
|
||||
// set position base on parent node
|
||||
newNode.position = getNestedNodePosition(newNode, selectedNode)
|
||||
// update parent children array like native add
|
||||
parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1737,7 +1736,17 @@ export const useNodesInteractions = () => {
|
|||
}
|
||||
})
|
||||
|
||||
setNodes([...nodes, ...nodesToPaste])
|
||||
const newNodes = produce(nodes, (draft: Node[]) => {
|
||||
parentChildrenToAppend.forEach(({ parentId, childId, childType }) => {
|
||||
const p = draft.find(n => n.id === parentId)
|
||||
if (p) {
|
||||
p.data._children?.push({ nodeId: childId, nodeType: childType })
|
||||
}
|
||||
})
|
||||
draft.push(...nodesToPaste)
|
||||
})
|
||||
|
||||
setNodes(newNodes)
|
||||
setEdges([...edges, ...edgesToPaste])
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodePaste, {
|
||||
nodeId: nodesToPaste?.[0]?.id,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue