Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu 2025-08-27 17:28:21 +08:00
commit ee144452e2
47 changed files with 3150 additions and 3614 deletions

View File

@ -1,6 +1,6 @@
#!/bin/bash
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv

View File

@ -95,7 +95,6 @@ class ToolBuiltinProviderInfoApi(Resource):
def get(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))

View File

@ -1,19 +1,20 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi
from .files import FileApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("web", __name__, url_prefix="/api")
api = ExternalApi(bp)
# Files
api.add_resource(FileApi, "/files/upload")
api = ExternalApi(
bp,
version="1.0",
title="Web API",
description="Public APIs for web applications including file uploads, chat interactions, and app management",
doc="/docs", # Enable Swagger UI at /api/docs
)
# Remote files
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Create namespace
web_ns = Namespace("web", description="Web application API operations", path="/")
from . import (
app,
@ -21,11 +22,15 @@ from . import (
completion,
conversation,
feature,
files,
forgot_password,
login,
message,
passport,
remote_files,
saved_message,
site,
workflow,
)
api.add_namespace(web_ns)

View File

@ -1,12 +1,21 @@
from flask_restx import Resource
from controllers.web import api
from controllers.web import web_ns
from services.feature_service import FeatureService
@web_ns.route("/system-features")
class SystemFeatureApi(Resource):
@web_ns.doc("get_system_features")
@web_ns.doc(description="Get system feature flags and configuration")
@web_ns.doc(responses={200: "System features retrieved successfully", 500: "Internal server error"})
def get(self):
"""Get system feature flags and configuration.
Returns the current system feature flags and configuration
that control various functionalities across the platform.
Returns:
dict: System feature configuration object
"""
return FeatureService.get_system_features().model_dump()
api.add_resource(SystemFeatureApi, "/system-features")

View File

@ -9,14 +9,50 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from fields.file_fields import build_file_model
from services.file_service import FileService
@web_ns.route("/files/upload")
class FileApi(WebApiResource):
@marshal_with(file_fields)
@web_ns.doc("upload_file")
@web_ns.doc(description="Upload a file for use in web applications")
@web_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request - invalid file or parameters",
413: "File too large",
415: "Unsupported file type",
}
)
@marshal_with(build_file_model(web_ns))
def post(self, app_model, end_user):
"""Upload a file for use in web applications.
Accepts file uploads for use within web applications, supporting
multiple file types with automatic validation and storage.
Args:
app_model: The associated application model
end_user: The end user uploading the file
Form Parameters:
file: The file to upload (required)
source: Optional source type (datasets or None)
Returns:
dict: File information including ID, URL, and metadata
int: HTTP status code 201 for success
Raises:
NoFileUploadedError: No file provided in request
TooManyFilesError: Multiple files provided (only one allowed)
FilenameNotExistsError: File has no filename
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
if "file" not in request.files:
raise NoFileUploadedError()

View File

@ -16,7 +16,7 @@ from controllers.console.auth.error import (
)
from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import api
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
@ -24,10 +24,21 @@ from models.account import Account
from services.account_service import AccountService
@web_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@web_ns.doc("send_forgot_password_email")
@web_ns.doc(description="Send password reset email")
@web_ns.doc(
responses={
200: "Password reset email sent successfully",
400: "Bad request - invalid email format",
404: "Account not found",
429: "Too many requests - rate limit exceeded",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -54,10 +65,16 @@ class ForgotPasswordSendEmailApi(Resource):
return {"result": "success", "data": token}
@web_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@web_ns.doc("check_forgot_password_token")
@web_ns.doc(description="Verify password reset token validity")
@web_ns.doc(
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@ -94,10 +111,21 @@ class ForgotPasswordCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@web_ns.doc("reset_password")
@web_ns.doc(description="Reset user password with verification token")
@web_ns.doc(
responses={
200: "Password reset successfully",
400: "Bad request - invalid parameters or password mismatch",
401: "Invalid or expired token",
404: "Account not found",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@ -141,8 +169,3 @@ class ForgotPasswordResetApi(Resource):
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -9,18 +9,30 @@ from controllers.console.auth.error import (
)
from controllers.console.error import AccountBannedError
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import api
from controllers.web import web_ns
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
from services.webapp_auth_service import WebAppAuthService
@web_ns.route("/login")
class LoginApi(Resource):
"""Resource for web app email/password login."""
@setup_required
@only_edition_enterprise
@web_ns.doc("web_app_login")
@web_ns.doc(description="Authenticate user for web application access")
@web_ns.doc(
responses={
200: "Authentication successful",
400: "Bad request - invalid email or password format",
401: "Authentication failed - email or password mismatch",
403: "Account banned or login disabled",
404: "Account not found",
}
)
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
@ -51,9 +63,19 @@ class LoginApi(Resource):
# return {"result": "success"}
@web_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@only_edition_enterprise
@web_ns.doc("send_email_code_login")
@web_ns.doc(description="Send email verification code for login")
@web_ns.doc(
responses={
200: "Email code sent successfully",
400: "Bad request - invalid email format",
404: "Account not found",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -74,9 +96,20 @@ class EmailCodeLoginSendEmailApi(Resource):
return {"result": "success", "data": token}
@web_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
@only_edition_enterprise
@web_ns.doc("verify_email_code_login")
@web_ns.doc(description="Verify email code and complete login")
@web_ns.doc(
responses={
200: "Email code verified and login successful",
400: "Bad request - invalid code or token",
401: "Invalid token or expired code",
404: "Account not found",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@ -104,9 +137,3 @@ class EmailCodeLoginApi(Resource):
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": {"access_token": token}}
api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@ -7,7 +7,7 @@ from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
@ -17,9 +17,19 @@ from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@web_ns.route("/passport")
class PassportResource(Resource):
"""Base resource for passport."""
@web_ns.doc("get_passport")
@web_ns.doc(description="Get authentication passport for web application access")
@web_ns.doc(
responses={
200: "Passport retrieved successfully",
401: "Unauthorized - missing app code or invalid authentication",
404: "Application or user not found",
}
)
def get(self):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
@ -94,9 +104,6 @@ class PassportResource(Resource):
}
api.add_resource(PassportResource, "/passport")
def decode_enterprise_webapp_user_id(jwt_token: str | None):
"""
Decode the enterprise user session from the Authorization header.

View File

@ -10,16 +10,44 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from services.file_service import FileService
@web_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
@web_ns.doc("get_remote_file_info")
@web_ns.doc(description="Get information about a remote file")
@web_ns.doc(
responses={
200: "Remote file information retrieved successfully",
400: "Bad request - invalid URL",
404: "Remote file not found",
500: "Failed to fetch remote file",
}
)
@marshal_with(build_remote_file_info_model(web_ns))
def get(self, app_model, end_user, url):
"""Get information about a remote file.
Retrieves basic information about a file located at a remote URL,
including content type and content length.
Args:
app_model: The associated application model
end_user: The end user making the request
url: URL-encoded path to the remote file
Returns:
dict: Remote file information including type and length
Raises:
HTTPException: If the remote file cannot be accessed
"""
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
@ -32,9 +60,42 @@ class RemoteFileInfoApi(WebApiResource):
}
@web_ns.route("/remote-files/upload")
class RemoteFileUploadApi(WebApiResource):
@marshal_with(file_fields_with_signed_url)
def post(self, app_model, end_user): # Add app_model and end_user parameters
@web_ns.doc("upload_remote_file")
@web_ns.doc(description="Upload a file from a remote URL")
@web_ns.doc(
responses={
201: "Remote file uploaded successfully",
400: "Bad request - invalid URL or parameters",
413: "File too large",
415: "Unsupported file type",
500: "Failed to fetch remote file",
}
)
@marshal_with(build_file_with_signed_url_model(web_ns))
def post(self, app_model, end_user):
"""Upload a file from a remote URL.
Downloads a file from the provided remote URL and uploads it
to the platform storage for use in web applications.
Args:
app_model: The associated application model
end_user: The end user making the request
JSON Parameters:
url: The remote URL to download the file from (required)
Returns:
dict: File information including ID, signed URL, and metadata
int: HTTP status code 201 for success
Raises:
RemoteFileUploadError: Failed to fetch file from remote URL
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()

View File

@ -210,13 +210,6 @@ class IndexingRunner:
documents.append(document)
# build index
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load(

View File

@ -401,7 +401,6 @@ class LLMGenerator:
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict:
app: App | None = db.session.query(App).where(App.id == flow_id).first()
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
@ -572,5 +571,7 @@ class LLMGenerator:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=e)
logger.exception(
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
return {"error": f"An unexpected error occurred: {str(e)}"}

View File

@ -276,7 +276,6 @@ class OracleVector(BaseVector):
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")

View File

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from yarl import URL
import contexts
@ -617,8 +618,9 @@ class ToolManager:
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(

View File

@ -329,22 +329,16 @@ class Executor:
"""
do http request depending on api bundle
"""
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
@ -362,11 +356,11 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response # type: ignore
return response
def invoke(self) -> Response:
# assemble headers

View File

@ -524,7 +524,12 @@ class LoopNode(BaseNode):
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if var_type in [
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
# 2. Consider moving this method to LoopVariableData class for better encapsulation
if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
elif var_type in [
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
@ -534,8 +539,6 @@ class LoopNode(BaseNode):
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
elif var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
else:
raise AssertionError("this statement should be unreachable.")
try:

View File

@ -292,7 +292,6 @@ class ClickZettaVolumeStorage(BaseStorage):
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":

View File

@ -7,7 +7,7 @@
import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from datetime import datetime
from enum import Enum
from typing import Any, Optional
@ -185,7 +185,6 @@ class FileLifecycleManager:
versions.append(current_metadata)
# 获取历史版本
version_pattern = f"{self._version_prefix}{filename}.v*"
try:
version_files = self._storage.scan(self._dataset_id or "", files=True)
for file_path in version_files:
@ -331,7 +330,6 @@ class FileLifecycleManager:
"""
try:
cleaned_count = 0
cutoff_date = datetime.now() - timedelta(days=max_age_days)
# 获取所有版本文件
try:

View File

@ -6,10 +6,10 @@ Create Date: 2025-08-09 15:53:54.341341
"""
from alembic import op
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
from sqlalchemy.sql import table, column
import uuid
# revision identifiers, used by Alembic.
revision = 'e8446f481c1e'
@ -21,7 +21,7 @@ depends_on = None
def upgrade():
# Create provider_credentials table
op.create_table('provider_credentials',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('provider_name', sa.String(length=255), nullable=False),
sa.Column('credential_name', sa.String(length=255), nullable=False),
@ -63,7 +63,7 @@ def migrate_existing_providers_data():
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
provider_credential_table = table('provider_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
@ -79,15 +79,15 @@ def migrate_existing_providers_data():
# Query all existing providers data
existing_providers = conn.execute(
sa.select(providers_table.c.id, providers_table.c.tenant_id,
sa.select(providers_table.c.id, providers_table.c.tenant_id,
providers_table.c.provider_name, providers_table.c.encrypted_config,
providers_table.c.created_at, providers_table.c.updated_at)
.where(providers_table.c.encrypted_config.isnot(None))
).fetchall()
# Iterate through each provider and insert into provider_credentials
for provider in existing_providers:
credential_id = str(uuid.uuid4())
credential_id = str(uuidv7())
if not provider.encrypted_config or provider.encrypted_config.strip() == '':
continue
@ -134,7 +134,7 @@ def downgrade():
def migrate_data_back_to_providers():
"""Migrate data back from provider_credentials to providers table for downgrade"""
# Define table structure for data manipulation
providers_table = table('providers',
column('id', models.types.StringUUID()),
@ -143,7 +143,7 @@ def migrate_data_back_to_providers():
column('encrypted_config', sa.Text()),
column('credential_id', models.types.StringUUID()),
)
provider_credential_table = table('provider_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
@ -160,18 +160,18 @@ def migrate_data_back_to_providers():
sa.select(providers_table.c.id, providers_table.c.credential_id)
.where(providers_table.c.credential_id.isnot(None))
).fetchall()
# For each provider, get the credential data and update providers table
for provider in providers_with_credentials:
credential = conn.execute(
sa.select(provider_credential_table.c.encrypted_config)
.where(provider_credential_table.c.id == provider.credential_id)
).fetchone()
if credential:
# Update providers table with encrypted_config from credential
conn.execute(
providers_table.update()
.where(providers_table.c.id == provider.id)
.values(encrypted_config=credential.encrypted_config)
)
)

View File

@ -5,9 +5,9 @@ Revises: e8446f481c1e
Create Date: 2025-08-13 16:05:42.657730
"""
import uuid
from alembic import op
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
from sqlalchemy.sql import table, column
@ -23,7 +23,7 @@ depends_on = None
def upgrade():
# Create provider_model_credentials table
op.create_table('provider_model_credentials',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('provider_name', sa.String(length=255), nullable=False),
sa.Column('model_name', sa.String(length=255), nullable=False),
@ -71,7 +71,7 @@ def migrate_existing_provider_models_data():
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
@ -90,19 +90,19 @@ def migrate_existing_provider_models_data():
# Query all existing provider_models data with encrypted_config
existing_provider_models = conn.execute(
sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id,
sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id,
provider_models_table.c.provider_name, provider_models_table.c.model_name,
provider_models_table.c.model_type, provider_models_table.c.encrypted_config,
provider_models_table.c.created_at, provider_models_table.c.updated_at)
.where(provider_models_table.c.encrypted_config.isnot(None))
).fetchall()
# Iterate through each provider_model and insert into provider_model_credentials
for provider_model in existing_provider_models:
if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '':
continue
credential_id = str(uuid.uuid4())
credential_id = str(uuidv7())
# Insert into provider_model_credentials table
conn.execute(
@ -148,14 +148,14 @@ def downgrade():
def migrate_data_back_to_provider_models():
"""Migrate data back from provider_model_credentials to provider_models table for downgrade"""
# Define table structure for data manipulation
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('encrypted_config', sa.Text()),
column('credential_id', models.types.StringUUID()),
)
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('encrypted_config', sa.Text()),
@ -169,14 +169,14 @@ def migrate_data_back_to_provider_models():
sa.select(provider_models_table.c.id, provider_models_table.c.credential_id)
.where(provider_models_table.c.credential_id.isnot(None))
).fetchall()
# For each provider_model, get the credential data and update provider_models table
for provider_model in provider_models_with_credentials:
credential = conn.execute(
sa.select(provider_model_credentials_table.c.encrypted_config)
.where(provider_model_credentials_table.c.id == provider_model.credential_id)
).fetchone()
if credential:
# Update provider_models table with encrypted_config from credential
conn.execute(

View File

@ -274,7 +274,7 @@ class ProviderCredential(Base):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -300,7 +300,7 @@ class ProviderModelCredential(Base):
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.7.2"
version = "1.8.0"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -47,7 +47,6 @@ def clean_messages():
if not messages:
break
for message in messages:
plan_sandbox_clean_message_day = message.created_at
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(

View File

@ -44,10 +44,10 @@ def queue_monitor_task():
if queue_length >= threshold:
warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}"
logger.warning(click.style(warning_msg, fg="red"))
alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
if alter_emails:
to_list = alter_emails.split(",")
logging.warning(click.style(warning_msg, fg="red"))
alert_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
if alert_emails:
to_list = alert_emails.split(",")
email_service = get_email_i18n_service()
for to in to_list:
try:

View File

@ -42,7 +42,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.3.1"
CURRENT_DSL_VERSION = "0.4.0"
class ImportMode(StrEnum):

View File

@ -2881,16 +2881,6 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content)
child_chunk_count = (
db.session.query(ChildChunk)
.where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.count()
)
max_position = (
db.session.query(func.max(ChildChunk.position))
.where(

View File

@ -9,6 +9,7 @@ from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import (
@ -185,9 +186,19 @@ class ExternalDatasetService:
"follow_redirects": True,
}
response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
data=json.dumps(settings.params), files=files, **kwargs
)
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = settings.request_method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {settings.request_method}")
response: httpx.Response = _METHOD_MAP[method_lc](data=json.dumps(settings.params), files=files, **kwargs)
return response
@staticmethod

View File

@ -1,7 +1,7 @@
import logging
from typing import Optional
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
@ -380,7 +380,7 @@ class ModelProviderService:
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
# Group models by provider
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
@ -391,9 +391,6 @@ class ModelProviderService:
if model.deprecated:
continue
if model.status != ModelStatus.ACTIVE:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list

View File

@ -134,12 +134,21 @@ class OpsService:
# get project url
if tracing_provider in ("arize", "phoenix"):
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:
project_url = None
elif tracing_provider == "langfuse":
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = f"{tracing_config.get('host')}/project/{project_key}"
try:
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = f"{tracing_config.get('host')}/project/{project_key}"
except Exception:
project_url = None
elif tracing_provider in ("langsmith", "opik"):
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:
project_url = None
else:
project_url = None

View File

@ -431,7 +431,7 @@ class BuiltinToolManageService:
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
@ -445,7 +445,7 @@ class BuiltinToolManageService:
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
@ -470,7 +470,7 @@ class BuiltinToolManageService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
@ -524,54 +524,53 @@ class BuiltinToolManageService:
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[ToolProviderApiEntity] = []
result: list[ToolProviderApiEntity] = []
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
raise e
result.append(user_builtin_provider)
except Exception as e:
raise e
return BuiltinToolProviderSort.sort(result)
@ -582,7 +581,7 @@ class BuiltinToolManageService:
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
with Session(db.engine, autoflush=False) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)

View File

@ -113,7 +113,7 @@ class WebAppAuthService:
@classmethod
def _get_account_jwt_token(cls, account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {

View File

@ -24,7 +24,6 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
"""
documents: list[Document] = []
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()

View File

@ -10,11 +10,13 @@ more reliable and realistic test scenarios.
import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Optional
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy import Engine, text
from sqlalchemy.orm import Session
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
@ -64,7 +66,7 @@ class DifyTestContainers:
# PostgreSQL is used for storing user data, workflows, and application state
logger.info("Initializing PostgreSQL container...")
self.postgres = PostgresContainer(
image="postgres:16-alpine",
image="postgres:14-alpine",
)
self.postgres.start()
db_host = self.postgres.get_container_host_ip()
@ -116,7 +118,7 @@ class DifyTestContainers:
# Start Redis container for caching and session management
# Redis is used for storing session data, cache entries, and temporary data
logger.info("Initializing Redis container...")
self.redis = RedisContainer(image="redis:latest", port=6379)
self.redis = RedisContainer(image="redis:6-alpine", port=6379)
self.redis.start()
redis_host = self.redis.get_container_host_ip()
redis_port = self.redis.get_exposed_port(6379)
@ -184,6 +186,57 @@ class DifyTestContainers:
_container_manager = DifyTestContainers()
def _get_migration_dir() -> Path:
conftest_dir = Path(__file__).parent
return conftest_dir.parent.parent / "migrations"
def _get_engine_url(engine: Engine):
try:
return engine.url.render_as_string(hide_password=False).replace("%", "%%")
except AttributeError:
return str(engine.url).replace("%", "%%")
_UUIDv7SQL = r"""
/* Main function to generate a uuidv7 value with millisecond precision */
CREATE FUNCTION uuidv7() RETURNS uuid
AS
$$
-- Replace the first 48 bits of a uuidv4 with the current
-- number of milliseconds since 1970-01-01 UTC
-- and set the "ver" field to 7 by setting additional bits
SELECT encode(
set_bit(
set_bit(
overlay(uuid_send(gen_random_uuid()) placing
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
3)
from 1 for 6),
52, 1),
53, 1), 'hex')::uuid;
$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
COMMENT ON FUNCTION uuidv7 IS
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
AS
$$
/* uuid fields: version=0b0111, variant=0b10 */
SELECT encode(
overlay('\x00000000000070008000000000000000'::bytea
placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3)
from 1 for 6),
'hex')::uuid;
$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0.
As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
"""
def _create_app_with_containers() -> Flask:
"""
Create Flask application configured to use test containers.
@ -211,7 +264,10 @@ def _create_app_with_containers() -> Flask:
# Initialize database schema
logger.info("Creating database schema...")
with app.app_context():
with db.engine.connect() as conn, conn.begin():
conn.execute(text(_UUIDv7SQL))
db.create_all()
logger.info("Database schema created successfully")

View File

@ -144,127 +144,6 @@ class TestAppDslService:
}
return yaml.dump(yaml_data, allow_unicode=True)
def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app import from YAML content.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create YAML content
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
# Import app
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name="Imported App",
description="Imported app description",
)
# Verify import result
assert result.status == ImportStatus.COMPLETED
assert result.app_id is not None
assert result.app_mode == "chat"
assert result.imported_dsl_version == "0.3.0"
assert result.error == ""
# Verify app was created in database
imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
assert imported_app is not None
assert imported_app.name == "Imported App"
assert imported_app.description == "Imported app description"
assert imported_app.mode == "chat"
assert imported_app.tenant_id == account.current_tenant_id
assert imported_app.created_by == account.id
# Verify model config was created
model_config = (
db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first()
)
assert model_config is not None
# The provider and model_id are stored in the model field as JSON
model_dict = model_config.model_dict
assert model_dict["provider"] == "openai"
assert model_dict["name"] == "gpt-3.5-turbo"
def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app import from YAML URL.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create YAML content for mock response
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
# Setup mock response
mock_response = MagicMock()
mock_response.content = yaml_content.encode("utf-8")
mock_response.raise_for_status.return_value = None
mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response
# Import app from URL
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_URL,
yaml_url="https://example.com/app.yaml",
name="URL Imported App",
description="App imported from URL",
)
# Verify import result
assert result.status == ImportStatus.COMPLETED
assert result.app_id is not None
assert result.app_mode == "chat"
assert result.imported_dsl_version == "0.3.0"
assert result.error == ""
# Verify app was created in database
imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
assert imported_app is not None
assert imported_app.name == "URL Imported App"
assert imported_app.description == "App imported from URL"
assert imported_app.mode == "chat"
assert imported_app.tenant_id == account.current_tenant_id
# Verify ssrf_proxy was called
mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with(
"https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10)
)
def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with invalid YAML format.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create invalid YAML content
invalid_yaml = "invalid: yaml: content: ["
# Import app with invalid YAML
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=invalid_yaml,
name="Invalid App",
)
# Verify import failed
assert result.status == ImportStatus.FAILED
assert result.app_id is None
assert "Invalid YAML format" in result.error
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with missing YAML content.

View File

@ -1067,7 +1067,7 @@ class TestModelProviderService:
# Verify mock interactions
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM)
mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies):
"""

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.7.2
image: langgenius/dify-api:1.8.0
restart: always
environment:
# Use the shared environment variables.
@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.7.2
image: langgenius/dify-api:1.8.0
restart: always
environment:
# Use the shared environment variables.
@ -58,7 +58,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.7.2
image: langgenius/dify-api:1.8.0
restart: always
environment:
# Use the shared environment variables.
@ -76,7 +76,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.7.2
image: langgenius/dify-web:1.8.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -580,7 +580,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.7.2
image: langgenius/dify-api:1.8.0
restart: always
environment:
# Use the shared environment variables.
@ -609,7 +609,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.7.2
image: langgenius/dify-api:1.8.0
restart: always
environment:
# Use the shared environment variables.
@ -636,7 +636,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.7.2
image: langgenius/dify-api:1.8.0
restart: always
environment:
# Use the shared environment variables.
@ -654,7 +654,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.7.2
image: langgenius/dify-web:1.8.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -1,47 +0,0 @@
#!/usr/bin/env python3
"""
Simple test to verify boolean classes can be imported correctly.
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
try:
# Test that we can import the boolean classes
from core.variables.segments import BooleanSegment, ArrayBooleanSegment
from core.variables.variables import BooleanVariable, ArrayBooleanVariable
from core.variables.types import SegmentType
print("✅ Successfully imported BooleanSegment")
print("✅ Successfully imported ArrayBooleanSegment")
print("✅ Successfully imported BooleanVariable")
print("✅ Successfully imported ArrayBooleanVariable")
print("✅ Successfully imported SegmentType")
# Test that the segment types exist
print(f"✅ SegmentType.BOOLEAN = {SegmentType.BOOLEAN}")
print(f"✅ SegmentType.ARRAY_BOOLEAN = {SegmentType.ARRAY_BOOLEAN}")
# Test creating boolean segments directly
bool_seg = BooleanSegment(value=True)
print(f"✅ Created BooleanSegment: {bool_seg}")
print(f" Value type: {bool_seg.value_type}")
print(f" Value: {bool_seg.value}")
array_bool_seg = ArrayBooleanSegment(value=[True, False, True])
print(f"✅ Created ArrayBooleanSegment: {array_bool_seg}")
print(f" Value type: {array_bool_seg.value_type}")
print(f" Value: {array_bool_seg.value}")
print("\n🎉 All boolean class imports and basic functionality work correctly!")
except ImportError as e:
print(f"❌ Import error: {e}")
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()

View File

@ -1,118 +0,0 @@
#!/usr/bin/env python3
"""
Simple test script to verify boolean condition support in IfElseNode
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
from core.workflow.utils.condition.processor import (
ConditionProcessor,
_evaluate_condition,
)
def test_boolean_conditions():
"""Test boolean condition evaluation"""
print("Testing boolean condition support...")
# Test boolean "is" operator
result = _evaluate_condition(value=True, operator="is", expected="true")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is' with True value passed")
result = _evaluate_condition(value=False, operator="is", expected="false")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is' with False value passed")
# Test boolean "is not" operator
result = _evaluate_condition(value=True, operator="is not", expected="false")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is not' with True value passed")
result = _evaluate_condition(value=False, operator="is not", expected="true")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'is not' with False value passed")
# Test boolean "=" operator
result = _evaluate_condition(value=True, operator="=", expected="1")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '=' with True=1 passed")
result = _evaluate_condition(value=False, operator="=", expected="0")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '=' with False=0 passed")
# Test boolean "≠" operator
result = _evaluate_condition(value=True, operator="", expected="0")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '' with True≠0 passed")
result = _evaluate_condition(value=False, operator="", expected="1")
assert result == True, f"Expected True, got {result}"
print("✓ Boolean '' with False≠1 passed")
# Test boolean "in" operator
result = _evaluate_condition(value=True, operator="in", expected=["true", "false"])
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'in' with True in array passed")
result = _evaluate_condition(value=False, operator="in", expected=["true", "false"])
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'in' with False in array passed")
# Test boolean "not in" operator
result = _evaluate_condition(value=True, operator="not in", expected=["false", "0"])
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'not in' with True not in [false, 0] passed")
# Test boolean "null" and "not null" operators
result = _evaluate_condition(value=True, operator="not null", expected=None)
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'not null' with True passed")
result = _evaluate_condition(value=False, operator="not null", expected=None)
assert result == True, f"Expected True, got {result}"
print("✓ Boolean 'not null' with False passed")
print("\n🎉 All boolean condition tests passed!")
def test_backward_compatibility():
"""Test that existing string and number conditions still work"""
print("\nTesting backward compatibility...")
# Test string conditions
result = _evaluate_condition(value="hello", operator="is", expected="hello")
assert result == True, f"Expected True, got {result}"
print("✓ String 'is' condition still works")
result = _evaluate_condition(value="hello", operator="contains", expected="ell")
assert result == True, f"Expected True, got {result}"
print("✓ String 'contains' condition still works")
# Test number conditions
result = _evaluate_condition(value=42, operator="=", expected="42")
assert result == True, f"Expected True, got {result}"
print("✓ Number '=' condition still works")
result = _evaluate_condition(value=42, operator=">", expected="40")
assert result == True, f"Expected True, got {result}"
print("✓ Number '>' condition still works")
print("✓ Backward compatibility maintained!")
if __name__ == "__main__":
try:
test_boolean_conditions()
test_backward_compatibility()
print(
"\n✅ All tests passed! Boolean support has been successfully added to IfElseNode."
)
except Exception as e:
print(f"\n❌ Test failed: {e}")
sys.exit(1)

View File

@ -1,67 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify the boolean array comparison fix in condition processor.
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
from core.workflow.utils.condition.processor import (
_assert_contains,
_assert_not_contains,
)
def test_boolean_array_contains():
"""Test that boolean arrays work correctly with string comparisons."""
# Test case 1: Boolean array [True, False, True] contains "true"
bool_array = [True, False, True]
# Should return True because "true" converts to True and True is in the array
result1 = _assert_contains(value=bool_array, expected="true")
print(f"Test 1 - [True, False, True] contains 'true': {result1}")
assert result1 == True, "Expected True but got False"
# Should return True because "false" converts to False and False is in the array
result2 = _assert_contains(value=bool_array, expected="false")
print(f"Test 2 - [True, False, True] contains 'false': {result2}")
assert result2 == True, "Expected True but got False"
# Test case 2: Boolean array [True, True] does not contain "false"
bool_array2 = [True, True]
result3 = _assert_contains(value=bool_array2, expected="false")
print(f"Test 3 - [True, True] contains 'false': {result3}")
assert result3 == False, "Expected False but got True"
# Test case 3: Test not_contains
result4 = _assert_not_contains(value=bool_array2, expected="false")
print(f"Test 4 - [True, True] not contains 'false': {result4}")
assert result4 == True, "Expected True but got False"
result5 = _assert_not_contains(value=bool_array, expected="true")
print(f"Test 5 - [True, False, True] not contains 'true': {result5}")
assert result5 == False, "Expected False but got True"
# Test case 4: Test with different string representations
result6 = _assert_contains(
value=bool_array, expected="1"
) # "1" should convert to True
print(f"Test 6 - [True, False, True] contains '1': {result6}")
assert result6 == True, "Expected True but got False"
result7 = _assert_contains(
value=bool_array, expected="0"
) # "0" should convert to False
print(f"Test 7 - [True, False, True] contains '0': {result7}")
assert result7 == True, "Expected True but got False"
print("\n✅ All boolean array comparison tests passed!")
if __name__ == "__main__":
test_boolean_array_contains()

View File

@ -1,99 +0,0 @@
#!/usr/bin/env python3
"""
Simple test script to verify boolean type inference in variable factory.
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
try:
from factories.variable_factory import build_segment, segment_to_variable
from core.variables.segments import BooleanSegment, ArrayBooleanSegment
from core.variables.variables import BooleanVariable, ArrayBooleanVariable
from core.variables.types import SegmentType
def test_boolean_inference():
print("Testing boolean type inference...")
# Test single boolean values
true_segment = build_segment(True)
false_segment = build_segment(False)
print(f"True value: {true_segment}")
print(f"Type: {type(true_segment)}")
print(f"Value type: {true_segment.value_type}")
print(f"Is BooleanSegment: {isinstance(true_segment, BooleanSegment)}")
print(f"\nFalse value: {false_segment}")
print(f"Type: {type(false_segment)}")
print(f"Value type: {false_segment.value_type}")
print(f"Is BooleanSegment: {isinstance(false_segment, BooleanSegment)}")
# Test array of booleans
bool_array_segment = build_segment([True, False, True])
print(f"\nBoolean array: {bool_array_segment}")
print(f"Type: {type(bool_array_segment)}")
print(f"Value type: {bool_array_segment.value_type}")
print(
f"Is ArrayBooleanSegment: {isinstance(bool_array_segment, ArrayBooleanSegment)}"
)
# Test empty boolean array
empty_bool_array = build_segment([])
print(f"\nEmpty array: {empty_bool_array}")
print(f"Type: {type(empty_bool_array)}")
print(f"Value type: {empty_bool_array.value_type}")
# Test segment to variable conversion
bool_var = segment_to_variable(
segment=true_segment, selector=["test", "bool_var"], name="test_boolean"
)
print(f"\nBoolean variable: {bool_var}")
print(f"Type: {type(bool_var)}")
print(f"Is BooleanVariable: {isinstance(bool_var, BooleanVariable)}")
array_bool_var = segment_to_variable(
segment=bool_array_segment,
selector=["test", "array_bool_var"],
name="test_array_boolean",
)
print(f"\nArray boolean variable: {array_bool_var}")
print(f"Type: {type(array_bool_var)}")
print(
f"Is ArrayBooleanVariable: {isinstance(array_bool_var, ArrayBooleanVariable)}"
)
# Test that bool comes before int (critical ordering)
print(f"\nTesting bool vs int precedence:")
print(f"True is instance of bool: {isinstance(True, bool)}")
print(f"True is instance of int: {isinstance(True, int)}")
print(f"False is instance of bool: {isinstance(False, bool)}")
print(f"False is instance of int: {isinstance(False, int)}")
# Verify that boolean values are correctly inferred as boolean, not int
assert true_segment.value_type == SegmentType.BOOLEAN, (
"True should be inferred as BOOLEAN"
)
assert false_segment.value_type == SegmentType.BOOLEAN, (
"False should be inferred as BOOLEAN"
)
assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN, (
"Boolean array should be inferred as ARRAY_BOOLEAN"
)
print("\n✅ All boolean inference tests passed!")
if __name__ == "__main__":
test_boolean_inference()
except ImportError as e:
print(f"Import error: {e}")
print("Make sure you're running this from the correct directory")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()

View File

@ -1,230 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify boolean support in VariableAssigner node
"""
import sys
import os
# Add the api directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api"))
from core.variables import SegmentType
from core.workflow.nodes.variable_assigner.v2.helpers import (
is_operation_supported,
is_constant_input_supported,
is_input_value_valid,
)
from core.workflow.nodes.variable_assigner.v2.enums import Operation
from core.workflow.nodes.variable_assigner.v2.constants import EMPTY_VALUE_MAPPING
def test_boolean_operation_support():
"""Test that boolean types support the correct operations"""
print("Testing boolean operation support...")
# Boolean should support SET, OVER_WRITE, and CLEAR
assert is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET
)
assert is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE
)
assert is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.CLEAR
)
# Boolean should NOT support arithmetic operations
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.ADD
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.SUBTRACT
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.MULTIPLY
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.DIVIDE
)
# Boolean should NOT support array operations
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.APPEND
)
assert not is_operation_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.EXTEND
)
print("✓ Boolean operation support tests passed")
def test_array_boolean_operation_support():
"""Test that array boolean types support the correct operations"""
print("Testing array boolean operation support...")
# Array boolean should support APPEND, EXTEND, SET, OVER_WRITE, CLEAR
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.EXTEND
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.OVER_WRITE
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.CLEAR
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.REMOVE_FIRST
)
assert is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.REMOVE_LAST
)
# Array boolean should NOT support arithmetic operations
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.ADD
)
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.SUBTRACT
)
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.MULTIPLY
)
assert not is_operation_supported(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.DIVIDE
)
print("✓ Array boolean operation support tests passed")
def test_boolean_constant_input_support():
"""Test that boolean types support constant input for correct operations"""
print("Testing boolean constant input support...")
# Boolean should support constant input for SET and OVER_WRITE
assert is_constant_input_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET
)
assert is_constant_input_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE
)
# Boolean should NOT support constant input for arithmetic operations
assert not is_constant_input_supported(
variable_type=SegmentType.BOOLEAN, operation=Operation.ADD
)
print("✓ Boolean constant input support tests passed")
def test_boolean_input_validation():
"""Test that boolean input validation works correctly"""
print("Testing boolean input validation...")
# Boolean values should be valid for boolean type
assert is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=True
)
assert is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=False
)
assert is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE, value=True
)
# Non-boolean values should be invalid for boolean type
assert not is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value="true"
)
assert not is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=1
)
assert not is_input_value_valid(
variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=0
)
print("✓ Boolean input validation tests passed")
def test_array_boolean_input_validation():
"""Test that array boolean input validation works correctly"""
print("Testing array boolean input validation...")
# Boolean values should be valid for array boolean append
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND, value=True
)
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND, value=False
)
# Boolean arrays should be valid for extend/overwrite
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.EXTEND,
value=[True, False, True],
)
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.OVER_WRITE,
value=[False, False],
)
# Non-boolean values should be invalid
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.APPEND,
value="true",
)
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_BOOLEAN,
operation=Operation.EXTEND,
value=[True, "false"],
)
print("✓ Array boolean input validation tests passed")
def test_empty_value_mapping():
"""Test that empty value mapping includes boolean types"""
print("Testing empty value mapping...")
# Check that boolean types have correct empty values
assert SegmentType.BOOLEAN in EMPTY_VALUE_MAPPING
assert EMPTY_VALUE_MAPPING[SegmentType.BOOLEAN] is False
assert SegmentType.ARRAY_BOOLEAN in EMPTY_VALUE_MAPPING
assert EMPTY_VALUE_MAPPING[SegmentType.ARRAY_BOOLEAN] == []
print("✓ Empty value mapping tests passed")
def main():
"""Run all tests"""
print("Running VariableAssigner boolean support tests...\n")
try:
test_boolean_operation_support()
test_array_boolean_operation_support()
test_boolean_constant_input_support()
test_boolean_input_validation()
test_array_boolean_input_validation()
test_empty_value_mapping()
print(
"\n🎉 All tests passed! Boolean support has been successfully added to VariableAssigner."
)
except Exception as e:
print(f"\n❌ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -27,10 +27,11 @@ const ChunkDetailModal: FC<Props> = ({
}) => {
const { t } = useTranslation()
const { segment, score, child_chunks } = payload
const { position, content, sign_content, keywords, document } = segment
const { position, content, sign_content, keywords, document, answer } = segment
const isParentChildRetrieval = !!(child_chunks && child_chunks.length > 0)
const extension = document.name.split('.').slice(-1)[0] as FileAppearanceTypeEnum
const heighClassName = isParentChildRetrieval ? 'h-[min(627px,_80vh)] overflow-y-auto' : 'h-[min(539px,_80vh)] overflow-y-auto'
const labelPrefix = isParentChildRetrieval ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
return (
<Modal
title={t(`${i18nPrefix}.chunkDetail`)}
@ -45,7 +46,7 @@ const ChunkDetailModal: FC<Props> = ({
<div className='flex items-center justify-between'>
<div className='flex grow items-center space-x-2'>
<SegmentIndexTag
labelPrefix={`${isParentChildRetrieval ? 'Parent-' : ''}Chunk`}
labelPrefix={labelPrefix}
positionId={position}
className={cn('w-fit group-hover:opacity-100')}
/>
@ -57,11 +58,29 @@ const ChunkDetailModal: FC<Props> = ({
</div>
<Score value={score} />
</div>
<Markdown
className={cn('!mt-2 !text-text-secondary', heighClassName)}
content={sign_content || content}
customDisallowedElements={['input']}
/>
{!answer && (
<Markdown
className={cn('!mt-2 !text-text-secondary', heighClassName)}
content={sign_content || content}
customDisallowedElements={['input']}
/>
)}
{answer && (
<div>
<div className='flex gap-x-1'>
<div className='w-4 shrink-0 text-[13px] font-medium leading-[20px] text-text-tertiary'>Q</div>
<div className={cn('body-md-regular line-clamp-20 text-text-secondary')}>
{content}
</div>
</div>
<div className='flex gap-x-1'>
<div className='w-4 shrink-0 text-[13px] font-medium leading-[20px] text-text-tertiary'>A</div>
<div className={cn('body-md-regular line-clamp-20 text-text-secondary')}>
{answer}
</div>
</div>
</div>
)}
{!isParentChildRetrieval && keywords && keywords.length > 0 && (
<div className='mt-6'>
<div className='text-xs font-medium uppercase text-text-tertiary'>{t(`${i18nPrefix}.keyword`)}</div>

View File

@ -9,6 +9,7 @@ import tailwind from 'eslint-plugin-tailwindcss'
import reactHooks from 'eslint-plugin-react-hooks'
import sonar from 'eslint-plugin-sonarjs'
import oxlint from 'eslint-plugin-oxlint'
import next from '@next/eslint-plugin-next'
// import reactRefresh from 'eslint-plugin-react-refresh'
@ -63,12 +64,14 @@ export default combine(
}),
unicorn(),
node(),
// use nextjs config will break @eslint/config-inspector
// use `ESLINT_CONFIG_INSPECTOR=true pnpx @eslint/config-inspector` to check the config
// ...process.env.ESLINT_CONFIG_INSPECTOR
// ? []
// Next.js configuration
{
plugins: {
'@next/next': next,
},
rules: {
...next.configs.recommended.rules,
...next.configs['core-web-vitals'].rules,
// performance issue, and not used.
'@next/next/no-html-link-for-pages': 'off',
},

View File

@ -582,6 +582,7 @@ export type Segment = {
keywords: string[]
hit_count: number
index_node_hash: string
answer: string
}
export type Document = {

View File

@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "1.7.2",
"version": "1.8.0",
"private": true,
"packageManager": "pnpm@10.15.0",
"engines": {
@ -25,7 +25,7 @@
"start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js",
"lint": "pnpx oxlint && pnpm eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache",
"lint-only-show-error": "pnpx oxlint && pnpm eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet",
"fix": "next lint --fix",
"fix": "eslint --fix .",
"eslint-fix": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --fix",
"eslint-fix-only-show-error": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --fix --quiet",
"eslint-complexity": "eslint --rule 'complexity: [error, {max: 15}]' --quiet",
@ -103,14 +103,14 @@
"mime": "^4.0.4",
"mitt": "^3.0.1",
"negotiator": "^0.6.3",
"next": "~15.3.5",
"next": "15.5.0",
"next-themes": "^0.4.3",
"pinyin-pro": "^3.25.0",
"qrcode.react": "^4.2.0",
"qs": "^6.13.0",
"react": "~19.1.0",
"react": "19.1.1",
"react-18-input-autosize": "^3.0.0",
"react-dom": "~19.1.0",
"react-dom": "19.1.1",
"react-easy-crop": "^5.1.0",
"react-error-boundary": "^4.1.2",
"react-headless-pagination": "^1.1.6",
@ -161,9 +161,9 @@
"@happy-dom/jest-environment": "^17.4.4",
"@mdx-js/loader": "^3.1.0",
"@mdx-js/react": "^3.1.0",
"@next/bundle-analyzer": "^15.4.1",
"@next/eslint-plugin-next": "~15.4.5",
"@next/mdx": "~15.3.5",
"@next/bundle-analyzer": "15.5.0",
"@next/eslint-plugin-next": "15.5.0",
"@next/mdx": "15.5.0",
"@rgrove/parse-xml": "^4.1.0",
"@storybook/addon-essentials": "8.5.0",
"@storybook/addon-interactions": "8.5.0",
@ -185,8 +185,8 @@
"@types/negotiator": "^0.6.3",
"@types/node": "18.15.0",
"@types/qs": "^6.9.16",
"@types/react": "~19.1.8",
"@types/react-dom": "~19.1.6",
"@types/react": "19.1.11",
"@types/react-dom": "19.1.7",
"@types/react-slider": "^1.3.6",
"@types/react-syntax-highlighter": "^15.5.13",
"@types/react-window": "^1.8.8",
@ -200,7 +200,7 @@
"code-inspector-plugin": "^0.18.1",
"cross-env": "^7.0.3",
"eslint": "^9.32.0",
"eslint-config-next": "~15.4.5",
"eslint-config-next": "15.5.0",
"eslint-plugin-oxlint": "^1.6.0",
"eslint-plugin-react-hooks": "^5.1.0",
"eslint-plugin-react-refresh": "^0.4.19",
@ -223,8 +223,8 @@
"uglify-js": "^3.19.3"
},
"resolutions": {
"@types/react": "~19.1.8",
"@types/react-dom": "~19.1.6",
"@types/react": "19.1.11",
"@types/react-dom": "19.1.7",
"string-width": "4.2.3"
},
"lint-staged": {

File diff suppressed because it is too large Load Diff