Merge branch 'main' into feat/model-auth

zxhlyh 2025-08-11 10:01:35 +08:00
commit 1a642084b5
305 changed files with 13495 additions and 934 deletions

.env.example — 1197 changed lines (file diff suppressed because it is too large)

View File

@ -4,6 +4,23 @@ title: "[Chore/Refactor] "
labels:
- refactor
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactoring; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have [searched for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit this report; otherwise it will be closed.
required: true
- label: 【中文用户 & Non-English User】Please submit in English; otherwise this issue will be closed.
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
id: description
attributes:

View File

@ -1,13 +1,18 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.ts'
permissions:
contents: write
pull-requests: write
jobs:
check-and-update:
if: github.event.pull_request.merged == true
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
defaults:
run:
@ -15,8 +20,8 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
persist-credentials: false
fetch-depth: 2
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
@ -27,6 +32,13 @@ jobs:
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .ts)
file_args="$file_args --file=$filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
@ -49,14 +61,15 @@ jobs:
if: env.FILES_CHANGED == 'true'
run: pnpm install --frozen-lockfile
- name: Run npm script
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
run: pnpm run auto-gen-i18n
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.

.gitignore — 1 changed line
View File

@ -215,3 +215,4 @@ mise.toml
# AI Assistant
.roo/
api/.env.backup
/clickzetta

View File

@ -19,7 +19,7 @@ RUN apt-get update \
# Install Python dependencies
COPY pyproject.toml uv.lock ./
RUN uv sync --locked
RUN uv sync --locked --no-dev
# production stage
FROM base AS production

View File

@ -9,7 +9,7 @@ import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
@ -181,8 +181,8 @@ def migrate_annotation_vector_database():
)
if not apps:
break
except NotFound:
break
except SQLAlchemyError:
raise
page += 1
for app in apps:
@ -308,8 +308,8 @@ def migrate_knowledge_vector_database():
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
except SQLAlchemyError:
raise
page += 1
for dataset in datasets:
@ -561,8 +561,8 @@ def old_metadata_migration():
.order_by(DatasetDocument.created_at.desc())
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
except SQLAlchemyError:
raise
if not documents:
break
for document in documents:

View File

@ -330,17 +330,17 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests")
] = 10
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
)
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests")
] = 60
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests")
] = 20
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
)
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for binary data in HTTP requests",

View File

@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
@ -52,6 +54,7 @@ class StorageConfig(BaseSettings):
"aliyun-oss",
"azure-blob",
"baidu-obs",
"clickzetta-volume",
"google-storage",
"huawei-obs",
"oci-storage",
@ -61,8 +64,9 @@ class StorageConfig(BaseSettings):
"local",
] = Field(
description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
)
@ -303,6 +307,7 @@ class MiddlewareConfig(
AliyunOSSStorageConfig,
AzureBlobStorageConfig,
BaiduOBSStorageConfig,
ClickZettaVolumeStorageConfig,
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
@ -315,6 +320,7 @@ class MiddlewareConfig(
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
MilvusConfig,
MyScaleConfig,

View File

@ -0,0 +1,65 @@
"""ClickZetta Volume Storage Configuration"""
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class ClickZettaVolumeStorageConfig(BaseSettings):
"""Configuration for ClickZetta Volume storage."""
CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
description="Username for ClickZetta Volume authentication",
default=None,
)
CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
description="Password for ClickZetta Volume authentication",
default=None,
)
CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
description="ClickZetta instance identifier",
default=None,
)
CLICKZETTA_VOLUME_SERVICE: str = Field(
description="ClickZetta service endpoint",
default="api.clickzetta.com",
)
CLICKZETTA_VOLUME_WORKSPACE: str = Field(
description="ClickZetta workspace name",
default="quick_start",
)
CLICKZETTA_VOLUME_VCLUSTER: str = Field(
description="ClickZetta virtual cluster name",
default="default_ap",
)
CLICKZETTA_VOLUME_SCHEMA: str = Field(
description="ClickZetta schema name",
default="dify",
)
CLICKZETTA_VOLUME_TYPE: str = Field(
description="ClickZetta volume type (table|user|external)",
default="user",
)
CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
description="ClickZetta volume name for external volumes",
default=None,
)
CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
description="Prefix for ClickZetta volume table names",
default="dataset_",
)
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
description="Directory prefix for User Volume to organize Dify files",
default="dify_km",
)

View File

@ -0,0 +1,69 @@
from typing import Optional
from pydantic import BaseModel, Field
class ClickzettaConfig(BaseModel):
"""
Clickzetta Lakehouse vector database configuration
"""
CLICKZETTA_USERNAME: Optional[str] = Field(
description="Username for authenticating with Clickzetta Lakehouse",
default=None,
)
CLICKZETTA_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Clickzetta Lakehouse",
default=None,
)
CLICKZETTA_INSTANCE: Optional[str] = Field(
description="Clickzetta Lakehouse instance ID",
default=None,
)
CLICKZETTA_SERVICE: Optional[str] = Field(
description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
default="api.clickzetta.com",
)
CLICKZETTA_WORKSPACE: Optional[str] = Field(
description="Clickzetta workspace name",
default="default",
)
CLICKZETTA_VCLUSTER: Optional[str] = Field(
description="Clickzetta virtual cluster name",
default="default_ap",
)
CLICKZETTA_SCHEMA: Optional[str] = Field(
description="Database schema name in Clickzetta",
default="public",
)
CLICKZETTA_BATCH_SIZE: Optional[int] = Field(
description="Batch size for bulk insert operations",
default=100,
)
CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field(
description="Enable inverted index for full-text search capabilities",
default=True,
)
CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field(
description="Analyzer type for full-text search: keyword, english, chinese, unicode",
default="chinese",
)
CLICKZETTA_ANALYZER_MODE: Optional[str] = Field(
description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
default="smart",
)
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
description="Distance function for vector similarity: l2_distance or cosine_distance",
default="cosine_distance",
)

View File

@ -225,14 +225,15 @@ class AnnotationBatchImportApi(Resource):
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")

View File

@ -28,6 +28,12 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class AppListApi(Resource):
@setup_required
@login_required
@ -94,7 +100,7 @@ class AppListApi(Resource):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
@ -146,7 +152,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@ -189,7 +195,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")

View File

@ -41,7 +41,7 @@ def _validate_name(name):
def _validate_description_length(description):
if len(description) > 400:
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@ -113,7 +113,7 @@ class DatasetListApi(Resource):
)
parser.add_argument(
"description",
type=str,
type=_validate_description_length,
nullable=True,
required=False,
default="",
@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
):
return {
"retrieval_method": [
@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
):
return {
"retrieval_method": [

View File

@ -49,7 +49,6 @@ class FileApi(Resource):
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
def post(self):
file = request.files["file"]
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -58,6 +57,7 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError

View File

@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
# get file from request
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# get file from request
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError

View File

@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

View File

@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class FileNotFoundError(BaseHTTPException):
error_code = "file_not_found"
description = "The requested file was not found."
code = 404
class FileAccessDeniedError(BaseHTTPException):
error_code = "file_access_denied"
description = "Access to the requested file is denied."
code = 403

View File

@ -20,18 +20,17 @@ class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model: App, end_user: EndUser):
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if not file.mimetype:
raise UnsupportedFileTypeError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError

View File

@ -0,0 +1,186 @@
import logging
from urllib.parse import quote
from flask import Response
from flask_restful import Resource, reqparse
from controllers.service_api import api
from controllers.service_api.app.error import (
FileAccessDeniedError,
FileNotFoundError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, EndUser, Message, MessageFile, UploadFile
logger = logging.getLogger(__name__)
class FilePreviewApi(Resource):
"""
Service API File Preview endpoint
Provides secure file preview/download functionality for external API users.
Files can only be accessed if they belong to messages within the requesting app's context.
"""
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: str):
"""
Preview/Download a file that was uploaded via Service API
Args:
app_model: The authenticated app model
end_user: The authenticated end user (optional)
file_id: UUID of the file to preview
Query Parameters:
user: Optional user identifier
as_attachment: Boolean, whether to download as attachment (default: false)
Returns:
Stream response with file content
Raises:
FileNotFoundError: File does not exist
FileAccessDeniedError: File access denied (not owned by app)
"""
file_id = str(file_id)
# Parse query parameters
parser = reqparse.RequestParser()
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:
generator = storage.load(upload_file.key, stream=True)
except Exception as e:
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
# Build response with appropriate headers
response = self._build_file_response(generator, upload_file, args["as_attachment"])
return response
def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]:
"""
Validate that the file belongs to a message within the requesting app's context
Security validations performed:
1. File exists in MessageFile table (was used in a conversation)
2. Message belongs to the requesting app
3. UploadFile record exists and is accessible
4. File tenant matches app tenant (additional security layer)
Args:
file_id: UUID of the file to validate
app_id: UUID of the requesting app
Returns:
Tuple of (MessageFile, UploadFile) if validation passes
Raises:
FileNotFoundError: File or related records not found
FileAccessDeniedError: File does not belong to the app's context
"""
try:
# Input validation
if not file_id or not app_id:
raise FileAccessDeniedError("Invalid file or app identifier")
# First, find the MessageFile that references this upload file
message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()
if not message_file:
raise FileNotFoundError("File not found in message context")
# Get the message and verify it belongs to the requesting app
message = (
db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
)
if not message:
raise FileAccessDeniedError("File access denied: not owned by requesting app")
# Get the actual upload file record
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise FileNotFoundError("Upload file record not found")
# Additional security: verify tenant isolation
app = db.session.query(App).where(App.id == app_id).first()
if app and upload_file.tenant_id != app.tenant_id:
raise FileAccessDeniedError("File access denied: tenant mismatch")
return message_file, upload_file
except (FileNotFoundError, FileAccessDeniedError):
# Re-raise our custom exceptions
raise
except Exception as e:
# Log unexpected errors for debugging
logger.exception(
"Unexpected error during file ownership validation",
extra={"file_id": file_id, "app_id": app_id, "error": str(e)},
)
raise FileAccessDeniedError("File access validation failed")
def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response:
"""
Build Flask Response object with appropriate headers for file streaming
Args:
generator: File content generator from storage
upload_file: UploadFile database record
as_attachment: Whether to set Content-Disposition as attachment
Returns:
Flask Response object with streaming file content
"""
response = Response(
generator,
mimetype=upload_file.mime_type,
direct_passthrough=True,
headers={},
)
# Add Content-Length if known
if upload_file.size and upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
# Add Accept-Ranges header for audio/video files to support seeking
if upload_file.mime_type in [
"audio/mpeg",
"audio/wav",
"audio/mp4",
"audio/ogg",
"audio/flac",
"audio/aac",
"video/mp4",
"video/webm",
"video/quicktime",
"audio/x-m4a",
]:
response.headers["Accept-Ranges"] = "bytes"
# Set Content-Disposition for downloads
if as_attachment and upload_file.name:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
# Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream"
# Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour
return response
# Register the API endpoint
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")

View File

@ -29,7 +29,7 @@ def _validate_name(name):
def _validate_description_length(description):
if len(description) > 400:
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource):
)
parser.add_argument(
"description",
type=str,
type=_validate_description_length,
nullable=True,
required=False,
default="",

View File

@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource):
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
)
# save file info
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
# save file info
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized
from controllers.common import fields
from controllers.web import api
@ -75,14 +76,14 @@ class AppWebAuthPermission(Resource):
try:
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise
raise Unauthorized("Authorization scheme must be 'Bearer'")
decoded = PassportService().verify(tk)
user_id = decoded.get("user_id", "visitor")

View File

@ -12,18 +12,17 @@ from services.file_service import FileService
class FileApi(WebApiResource):
@marshal_with(file_fields)
def post(self, app_model, end_user):
file = request.files["file"]
source = request.form.get("source")
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError
source = request.form.get("source")
if source not in ("datasets", None):
source = None

View File

@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
db_conversation_variables = session.scalars(stmt).all()
if not db_conversation_variables:
# Create conversation variables if they don't exist.
db_conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in self._workflow.conversation_variables
]
session.add_all(db_conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in db_conversation_variables]
session.commit()
# Initialize conversation variables
conversation_variables = self._initialize_conversation_variables()
# Create a variable pool.
system_inputs = SystemVariable(
@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
"""
Initialize conversation variables for the current conversation.
This method:
1. Loads existing variables from the database
2. Creates new variables if none exist
3. Syncs missing variables from the workflow definition
:return: List of conversation variables ready for use
"""
with Session(db.engine) as session:
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:
# First time initialization - create all variables
existing_variables = self._create_all_conversation_variables(session)
else:
# Check and add any missing variables from the workflow
existing_variables = self._sync_missing_conversation_variables(session, existing_variables)
# Convert to Variable objects for use in the workflow
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
Load existing conversation variables from the database.
:param session: Database session
:return: List of existing conversation variables
"""
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
return list(session.scalars(stmt).all())
def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
Create all conversation variables for a new conversation.
:param session: Database session
:return: List of created conversation variables
"""
new_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in self._workflow.conversation_variables
]
if new_variables:
session.add_all(new_variables)
return new_variables
def _sync_missing_conversation_variables(
self, session: Session, existing_variables: list[ConversationVariable]
) -> list[ConversationVariable]:
"""
Sync missing conversation variables from the workflow definition.
This handles the case where new variables are added to a workflow
after conversations have already been created.
:param session: Database session
:param existing_variables: List of existing conversation variables
:return: Updated list including any newly created variables
"""
# Get IDs of existing and workflow variables
existing_ids = {var.id for var in existing_variables}
workflow_variables = {var.id: var for var in self._workflow.conversation_variables}
# Find missing variable IDs
missing_ids = set(workflow_variables.keys()) - existing_ids
if not missing_ids:
return existing_variables
# Create missing variables with their default values
new_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id,
conversation_id=self.conversation.id,
variable=workflow_variables[var_id],
)
for var_id in missing_ids
]
session.add_all(new_variables)
# Return combined list
return existing_variables + new_variables

View File

@ -23,6 +23,7 @@ from core.app.entities.task_entities import (
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
StreamEvent,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
@ -180,11 +181,15 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel):
continue
status = ModelStatus.ACTIVE
if m.model in model_setting_map:
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED

View File

@ -121,9 +121,8 @@ class TokenBufferMemory:
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
pruned_memory.append(prompt_messages.pop(0))
prompt_messages.pop(0)
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages

View File

@ -0,0 +1,190 @@
# Clickzetta Vector Database Integration
This module provides integration with Clickzetta Lakehouse as a vector database for Dify.
## Features
- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
- **Vector Search**: Efficient similarity search using HNSW algorithm
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
- **Hybrid Search**: Combine vector similarity and full-text search for better results
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments
## Configuration
### Required Environment Variables
All seven configuration parameters are required:
```bash
# Authentication
CLICKZETTA_USERNAME=your_username
CLICKZETTA_PASSWORD=your_password
# Instance configuration
CLICKZETTA_INSTANCE=your_instance_id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=your_workspace
CLICKZETTA_VCLUSTER=your_vcluster
CLICKZETTA_SCHEMA=your_schema
```
### Optional Configuration
```bash
# Batch processing
CLICKZETTA_BATCH_SIZE=100
# Full-text search configuration
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode
CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart
# Vector search configuration
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance
```
## Usage
### 1. Set Clickzetta as the Vector Store
In your Dify configuration, set:
```bash
VECTOR_STORE=clickzetta
```
### 2. Table Structure
Clickzetta will automatically create tables with the following structure:
```sql
CREATE TABLE <collection_name> (
id STRING NOT NULL,
content STRING NOT NULL,
metadata JSON,
vector VECTOR(FLOAT, <dimension>) NOT NULL,
PRIMARY KEY (id)
);
-- Vector index for similarity search
CREATE VECTOR INDEX idx_<collection_name>_vec
ON TABLE <schema>.<collection_name>(vector)
PROPERTIES (
"distance.function" = "cosine_distance",
"scalar.type" = "f32"
);
-- Inverted index for full-text search (if enabled)
CREATE INVERTED INDEX idx_<collection_name>_text
ON <schema>.<collection_name>(content)
PROPERTIES (
"analyzer" = "chinese",
"mode" = "smart"
);
```
## Full-Text Search Capabilities
Clickzetta supports advanced full-text search with multiple analyzers:
### Analyzer Types
1. **keyword**: No tokenization, treats the entire string as a single token
- Best for: Exact matching, IDs, codes
2. **english**: Designed for English text
- Features: Recognizes ASCII letters and numbers, converts to lowercase
- Best for: English content
3. **chinese**: Chinese text tokenizer
- Features: Recognizes Chinese and English characters, removes punctuation
- Best for: Chinese or mixed Chinese-English content
4. **unicode**: Multi-language tokenizer based on Unicode
- Features: Recognizes text boundaries in multiple languages
- Best for: Multi-language content
### Analyzer Modes
- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced)
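To see the difference on a concrete string, the `TOKENIZE()` function (used again in the Troubleshooting section below) can be called with either mode. A minimal sketch, assuming the function is available in your session; the exact token lists it returns depend on the Clickzetta version:
```sql
-- Fine-grained tokenization (more, shorter tokens)
SELECT TOKENIZE('Dify知识库混合检索', map('analyzer', 'chinese', 'mode', 'max_word'));
-- Intelligent tokenization (fewer, balanced tokens)
SELECT TOKENIZE('Dify知识库混合检索', map('analyzer', 'chinese', 'mode', 'smart'));
```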
### Full-Text Search Functions
- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
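As an illustration, these predicates are used directly in a `WHERE` clause. A minimal sketch, assuming a collection table created as in the Table Structure section above (`<schema>` and `<collection_name>` are placeholders):
```sql
-- All terms must appear in the content column
SELECT id, content
FROM <schema>.<collection_name>
WHERE MATCH_ALL(content, 'vector database')
LIMIT 10;
-- Exact phrase matching
SELECT id, content
FROM <schema>.<collection_name>
WHERE MATCH_PHRASE(content, 'hybrid search')
LIMIT 10;
```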
## Performance Optimization
### Vector Search
1. **Adjust exploration factor** for accuracy vs speed trade-off:
```sql
SET cz.vector.index.search.ef=64;
```
2. **Use appropriate distance functions**:
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
- `l2_distance`: Best for raw feature vectors
### Full-Text Search
1. **Choose the right analyzer**:
- Use `keyword` for exact matching
- Use language-specific analyzers for better tokenization
2. **Combine with vector search**:
- Pre-filter with full-text search for better performance
- Use hybrid search for improved relevance
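A rough sketch of that pre-filtering pattern, reusing the placeholder names from the Table Structure section and the distance-predicate form from the Troubleshooting section (`<query_vector>` stands in for the embedding literal, whose exact syntax depends on the client):
```sql
-- Narrow candidates with the inverted index first, then apply the
-- vector similarity predicate on the remaining rows
SELECT id, content, metadata
FROM <schema>.<collection_name>
WHERE MATCH_ANY(content, 'clickzetta lakehouse')
  AND cosine_distance(vector, <query_vector>) < 0.5
LIMIT 4;
```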
## Troubleshooting
### Connection Issues
1. Verify all 7 required configuration parameters are set
2. Check network connectivity to Clickzetta service
3. Ensure the user has proper permissions on the schema
### Search Performance
1. Verify vector index exists:
```sql
SHOW INDEX FROM <schema>.<table_name>;
```
2. Check if vector index is being used:
```sql
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
```
Look for `vector_index_search_type` in the execution plan.
### Full-Text Search Not Working
1. Verify inverted index is created
2. Check analyzer configuration matches your content language
3. Use `TOKENIZE()` function to test tokenization:
```sql
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
```
## Limitations
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
2. Full-text search relevance scores are not provided by Clickzetta
3. Inverted index creation may fail for very large existing tables (the integration continues without raising an error)
4. Index naming constraints:
- Index names must be unique within a schema
- Only one vector index can be created per column
- The implementation uses timestamps to ensure unique index names
5. A column can only have one vector index at a time
## References
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)

View File

@ -0,0 +1 @@
# Clickzetta Vector Database Integration for Dify

File diff suppressed because it is too large

View File

@ -246,6 +246,10 @@ class TencentVector(BaseVector):
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
if not self._enable_hybrid_search:
return []
res = self._client.hybrid_search(
@ -269,6 +273,7 @@ class TencentVector(BaseVector):
),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
filter=filter,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

View File

@ -172,6 +172,10 @@ class Vector:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -30,3 +30,4 @@ class VectorType(StrEnum):
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"

View File

@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load given path as single page."""
content = self.parse_docx(self.file_path, "storage")
content = self.parse_docx(self.file_path)
return [
Document(
page_content=content,
@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor):
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
def _parse_paragraph(self, paragraph, image_map):
paragraph_content = []
for run in paragraph.runs:
if run.element.xpath(".//a:blip"):
for blip in run.element.xpath(".//a:blip"):
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if embed_id:
rel_target = run.part.rels[embed_id].target_ref
if rel_target in image_map:
paragraph_content.append(image_map[rel_target])
if run.text.strip():
paragraph_content.append(run.text.strip())
return " ".join(paragraph_content) if paragraph_content else ""
def parse_docx(self, docx_path, image_folder):
def parse_docx(self, docx_path):
doc = DocxDocument(docx_path)
os.makedirs(image_folder, exist_ok=True)
content = []

View File

@ -5,14 +5,13 @@ from __future__ import annotations
from typing import Any, Optional
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
TokenTextSplitter,
Union,
)
@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_character_encoder, **kwargs)

View File

@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool):
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try:
if local_tz is None:
local_tz = datetime.now().astimezone().tzinfo
if isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
local_time = datetime.strptime(localtime, time_format)
localtime = local_tz.localize(local_time) # type: ignore
if local_tz is None:
localtime = local_time.astimezone() # type: ignore
elif isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) # type: ignore
return timestamp
except Exception as e:

View File

@ -1,7 +1,8 @@
import json
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib.parse import urlencode
import httpx
@ -20,6 +21,20 @@ API_TOOL_DEFAULT_TIMEOUT = (
)
@dataclass
class ParsedResponse:
"""Represents a parsed HTTP response with type information"""
content: Union[str, dict]
is_json: bool
def to_string(self) -> str:
"""Convert response to string format for credential validation"""
if isinstance(self.content, dict):
return json.dumps(self.content, ensure_ascii=False)
return str(self.content)
class ApiTool(Tool):
"""
Api tool
@ -58,7 +73,9 @@ class ApiTool(Tool):
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response
return self.validate_and_parse_response(response)
parsed_response = self.validate_and_parse_response(response)
# For credential validation, always return as string
return parsed_response.to_string()
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.API
@ -112,23 +129,36 @@ class ApiTool(Tool):
return headers
def validate_and_parse_response(self, response: httpx.Response) -> str:
def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
"""
validate the response
validate the response and return parsed content with type information
:return: ParsedResponse with content and is_json flag
"""
if isinstance(response, httpx.Response):
if response.status_code >= 400:
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
if not response.content:
return "Empty response from the tool, please check your parameters and try again."
return ParsedResponse(
"Empty response from the tool, please check your parameters and try again.", False
)
# Check content type
content_type = response.headers.get("content-type", "").lower()
is_json_content_type = "application/json" in content_type
# Try to parse as JSON
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception:
return json.dumps(response)
json_response = response.json()
# If content-type indicates JSON, return as JSON object
if is_json_content_type:
return ParsedResponse(json_response, True)
else:
# If content-type doesn't indicate JSON, treat as text regardless of content
return ParsedResponse(response.text, False)
except Exception:
return response.text
# Not valid JSON, return as text
return ParsedResponse(response.text, False)
else:
raise ValueError(f"Invalid response type {type(response)}")
@ -369,7 +399,14 @@ class ApiTool(Tool):
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
# validate response
response = self.validate_and_parse_response(response)
parsed_response = self.validate_and_parse_response(response)
# assemble invoke message
yield self.create_text_message(response)
# assemble invoke message based on response type
if parsed_response.is_json and isinstance(parsed_response.content, dict):
yield self.create_json_message(parsed_response.content)
else:
# Convert to string if needed and create text message
text_response = (
parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
)
yield self.create_text_message(text_response)

View File

@ -29,7 +29,7 @@ from core.tools.errors import (
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatorUserRole
@ -247,7 +247,8 @@ class ToolEngine:
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
else:
result += str(response.message)

View File

@ -1,7 +1,14 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional
from typing import Optional, cast
from uuid import UUID
import numpy as np
import pytz
from flask_login import current_user
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
def safe_json_value(v):
if isinstance(v, datetime):
tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
if not tz_name:
tz_name = "UTC"
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v
def safe_json_dict(d):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
@ -113,6 +155,12 @@ class ToolFileMessageTransformer:
)
else:
yield message
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
yield message
else:
yield message

View File

@ -119,6 +119,13 @@ class ObjectSegment(Segment):
class ArraySegment(Segment):
@property
def text(self) -> str:
# Return empty string for empty arrays instead of "[]"
if not self.value:
return ""
return super().text
@property
def markdown(self) -> str:
items = []
@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment):
@property
def text(self) -> str:
# Return empty string for empty arrays instead of "[]"
if not self.value:
return ""
return json.dumps(self.value, ensure_ascii=False)

View File

@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
"""Extract text from a file based on its file extension."""
match file_extension:
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
case (
".txt"
| ".markdown"
| ".md"
| ".html"
| ".htm"
| ".xml"
| ".c"
| ".h"
| ".cpp"
| ".hpp"
| ".cc"
| ".cxx"
| ".c++"
| ".py"
| ".js"
| ".ts"
| ".jsx"
| ".tsx"
| ".java"
| ".php"
| ".rb"
| ".go"
| ".rs"
| ".swift"
| ".kt"
| ".scala"
| ".sh"
| ".bash"
| ".bat"
| ".ps1"
| ".sql"
| ".r"
| ".m"
| ".pl"
| ".lua"
| ".vim"
| ".asm"
| ".s"
| ".css"
| ".scss"
| ".less"
| ".sass"
| ".ini"
| ".cfg"
| ".conf"
| ".toml"
| ".env"
| ".log"
| ".vtt"
):
return _extract_text_from_plain_text(file_content)
case ".json":
return _extract_text_from_json(file_content)
@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
return _extract_text_from_eml(file_content)
case ".msg":
return _extract_text_from_msg(file_content)
case ".vtt":
return _extract_text_from_vtt(file_content)
case ".properties":
return _extract_text_from_properties(file_content)
case _:

View File

@ -91,7 +91,7 @@ class Executor:
self.auth = node_data.authorization
self.timeout = timeout
self.ssl_verify = node_data.ssl_verify
self.params = []
self.params = None
self.headers = {}
self.content = None
self.files = None
@ -139,7 +139,8 @@ class Executor:
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
)
self.params = result
if result:
self.params = result
def _init_headers(self):
"""

View File

@ -69,6 +69,19 @@ class Storage:
from extensions.storage.supabase_storage import SupabaseStorage
return SupabaseStorage
case StorageType.CLICKZETTA_VOLUME:
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
ClickZettaVolumeConfig,
ClickZettaVolumeStorage,
)
def create_clickzetta_volume_storage():
# ClickZettaVolumeConfig will automatically read from environment variables
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
volume_config = ClickZettaVolumeConfig()
return ClickZettaVolumeStorage(volume_config)
return create_clickzetta_volume_storage
case _:
raise ValueError(f"unsupported storage type {storage_type}")

View File

@ -0,0 +1,5 @@
"""ClickZetta Volume storage implementation."""
from .clickzetta_volume_storage import ClickZettaVolumeStorage
__all__ = ["ClickZettaVolumeStorage"]

View File

@ -0,0 +1,530 @@
"""ClickZetta Volume Storage Implementation
This module provides storage backend using ClickZetta Volume functionality.
Supports Table Volume, User Volume, and External Volume types.
"""
import logging
import os
import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Optional
import clickzetta # type: ignore[import]
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage
from .volume_permissions import VolumePermissionManager, check_volume_permission
logger = logging.getLogger(__name__)
class ClickZettaVolumeConfig(BaseModel):
"""Configuration for ClickZetta Volume storage."""
username: str = ""
password: str = ""
instance: str = ""
service: str = "api.clickzetta.com"
workspace: str = "quick_start"
vcluster: str = "default_ap"
schema_name: str = "dify"
volume_type: str = "table" # table|user|external
volume_name: Optional[str] = None # For external volumes
table_prefix: str = "dataset_" # Prefix for table volume names
dify_prefix: str = "dify_km" # Directory prefix for User Volume
permission_check: bool = True # Enable/disable permission checking
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
then fall back to CLICKZETTA_* environment variables (for vector DB config).
"""
import os
# Helper function to get environment variable with fallback
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
# First try CLICKZETTA_VOLUME_* specific config
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
if volume_value:
return str(volume_value)
# Then try environment variables
volume_env = os.getenv(volume_key)
if volume_env:
return volume_env
# Fall back to existing CLICKZETTA_* config
fallback_env = os.getenv(fallback_key)
if fallback_env:
return fallback_env
return default or ""
# Apply environment variables with fallback to existing CLICKZETTA_* config
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
values.setdefault(
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
)
values.setdefault(
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
)
values.setdefault(
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
)
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
# Volume-specific configurations (no fallback to vector DB config)
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
# Temporarily disable the permission check feature by setting it to False directly
values.setdefault("permission_check", False)
# Validate required fields
if not values.get("username"):
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
if not values.get("password"):
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
if not values.get("instance"):
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
# Validate volume type
volume_type = values["volume_type"]
if volume_type not in ["table", "user", "external"]:
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
if volume_type == "external" and not values.get("volume_name"):
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
return values
class ClickZettaVolumeStorage(BaseStorage):
"""ClickZetta Volume storage implementation."""
def __init__(self, config: ClickZettaVolumeConfig):
"""Initialize ClickZetta Volume storage.
Args:
config: ClickZetta Volume configuration
"""
self._config = config
self._connection = None
self._permission_manager: VolumePermissionManager | None = None
self._init_connection()
self._init_permission_manager()
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
def _init_connection(self):
"""Initialize ClickZetta connection."""
try:
self._connection = clickzetta.connect(
username=self._config.username,
password=self._config.password,
instance=self._config.instance,
service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
schema=self._config.schema_name,
)
logger.debug("ClickZetta connection established")
except Exception as e:
logger.exception("Failed to connect to ClickZetta")
raise
def _init_permission_manager(self):
"""Initialize permission manager."""
try:
self._permission_manager = VolumePermissionManager(
self._connection, self._config.volume_type, self._config.volume_name
)
logger.debug("Permission manager initialized")
except Exception as e:
logger.exception("Failed to initialize permission manager")
raise
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
"""Get the appropriate volume path based on volume type."""
if self._config.volume_type == "user":
# Add dify prefix for User Volume to organize files
return f"{self._config.dify_prefix}/{filename}"
elif self._config.volume_type == "table":
# Check if this should use User Volume (special directories)
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
# Use User Volume with dify prefix for special directories
return f"{self._config.dify_prefix}/{filename}"
if dataset_id:
return f"{self._config.table_prefix}{dataset_id}/{filename}"
else:
# Extract dataset_id from filename if not provided
# Format: dataset_id/filename
if "/" in filename:
return filename
else:
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
elif self._config.volume_type == "external":
return filename
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
"""Get SQL prefix for volume operations."""
if self._config.volume_type == "user":
return "USER VOLUME"
elif self._config.volume_type == "table":
# For Dify's current file storage pattern, most files are stored in
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
# These should use USER VOLUME for better compatibility
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
return "USER VOLUME"
# Only use TABLE VOLUME for actual dataset-specific paths
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
if dataset_id:
table_name = f"{self._config.table_prefix}{dataset_id}"
else:
# Default table name for generic operations
table_name = "default_dataset"
return f"TABLE VOLUME {table_name}"
elif self._config.volume_type == "external":
return f"VOLUME {self._config.volume_name}"
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _execute_sql(self, sql: str, fetch: bool = False):
"""Execute SQL command."""
try:
if self._connection is None:
raise RuntimeError("Connection not initialized")
with self._connection.cursor() as cursor:
cursor.execute(sql)
if fetch:
return cursor.fetchall()
return None
except Exception as e:
logger.exception("SQL execution failed: %s", sql)
raise
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
"""Ensure table volume exists for the given dataset_id."""
if self._config.volume_type != "table" or not dataset_id:
return
# Skip for upload_files and other special directories that use USER VOLUME
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
return
table_name = f"{self._config.table_prefix}{dataset_id}"
try:
# Check if table exists
check_sql = f"SHOW TABLES LIKE '{table_name}'"
result = self._execute_sql(check_sql, fetch=True)
if not result:
# Create table with volume
create_sql = f"""
CREATE TABLE {table_name} (
id INT PRIMARY KEY AUTO_INCREMENT,
filename VARCHAR(255) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_filename (filename)
) WITH VOLUME
"""
self._execute_sql(create_sql)
logger.info("Created table volume: %s", table_name)
except Exception as e:
logger.warning("Failed to create table volume %s: %s", table_name, e)
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
def save(self, filename: str, data: bytes) -> None:
"""Save data to ClickZetta Volume.
Args:
filename: File path in volume
data: File content as bytes
"""
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
# Ensure table volume exists (for table volumes)
if dataset_id:
self._ensure_table_volume_exists(dataset_id)
# Check permissions (if enabled)
if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
if self._permission_manager is not None:
check_volume_permission(self._permission_manager, "save", dataset_id)
# Write data to temporary file
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(data)
temp_file_path = temp_file.name
try:
# Upload to volume
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
else:
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
self._execute_sql(sql)
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
finally:
# Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True)
def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume.
Args:
filename: File path in volume
Returns:
File content as bytes
"""
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
# Check permissions (if enabled)
if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
if self._permission_manager is not None:
check_volume_permission(self._permission_manager, "load_once", dataset_id)
# Download to temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
else:
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
self._execute_sql(sql)
# Find the downloaded file (may be in subdirectories)
downloaded_file = None
for root, dirs, files in os.walk(temp_dir):
for file in files:
if file == filename or file == os.path.basename(filename):
downloaded_file = Path(root) / file
break
if downloaded_file:
break
if not downloaded_file or not downloaded_file.exists():
raise FileNotFoundError(f"Downloaded file not found: {filename}")
content = downloaded_file.read_bytes()
logger.debug("File %s loaded from ClickZetta Volume", filename)
return content
def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume.
Args:
filename: File path in volume
Yields:
File content chunks
"""
content = self.load_once(filename)
batch_size = 4096
stream = BytesIO(content)
while chunk := stream.read(batch_size):
yield chunk
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path.
Args:
filename: File path in volume
target_filepath: Local target file path
"""
content = self.load_once(filename)
with Path(target_filepath).open("wb") as f:
f.write(content)
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume.
Args:
filename: File path in volume
Returns:
True if file exists, False otherwise
"""
try:
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
else:
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
rows = self._execute_sql(sql, fetch=True)
exists = len(rows) > 0
logger.debug("File %s exists check: %s", filename, exists)
return exists
except Exception as e:
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
def delete(self, filename: str):
"""Delete file from ClickZetta Volume.
Args:
filename: File path in volume
"""
if not self.exists(filename):
logger.debug("File %s not found, skip delete", filename)
return
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
else:
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
self._execute_sql(sql)
logger.debug("File %s deleted from ClickZetta Volume", filename)
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume.
Args:
path: Path to scan (dataset_id for table volumes)
files: Include files in results
directories: Include directories in results
Returns:
List of file/directory paths
"""
try:
# For table volumes, path is treated as dataset_id
dataset_id = None
if self._config.volume_type == "table":
dataset_id = path
path = "" # Root of the table volume
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# For User Volume, add dify prefix to path
if volume_prefix == "USER VOLUME":
if path:
scan_path = f"{self._config.dify_prefix}/{path}"
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
else:
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
else:
if path:
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
else:
sql = f"LIST {volume_prefix}"
rows = self._execute_sql(sql, fetch=True)
result = []
for row in rows:
file_path = row[0] # relative_path column
# For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
if (files and not file_path.endswith("/")) or (directories and file_path.endswith("/")):
result.append(file_path)
logger.debug("Scanned %d items in path %s", len(result), path)
return result
except Exception as e:
logger.exception("Error scanning path %s", path)
return []
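For orientation, a minimal usage sketch of the storage class above follows. It assumes the ClickZettaVolumeConfig fields shown in the integration tests later in this commit; the credentials and file path are placeholders, not real values.

```python
# Minimal sketch, assuming ClickZettaVolumeConfig exposes the fields used in the
# integration tests below; all credential values here are placeholders.
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
    ClickZettaVolumeConfig,
    ClickZettaVolumeStorage,
)

config = ClickZettaVolumeConfig(
    username="dify_user",
    password="***",
    instance="my_instance",
    service="api.clickzetta.com",
    workspace="quick_start",
    vcluster="default_ap",
    schema_name="dify",
    volume_type="user",  # or "table" / "external"
)
storage = ClickZettaVolumeStorage(config)

# Basic round trip through the public API defined above
storage.save("upload_files/tenant_1/example.txt", b"hello")
print(storage.exists("upload_files/tenant_1/example.txt"))     # True
print(storage.load_once("upload_files/tenant_1/example.txt"))  # b"hello"
storage.delete("upload_files/tenant_1/example.txt")
```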

View File

@ -0,0 +1,516 @@
"""ClickZetta Volume文件生命周期管理
该模块提供文件版本控制自动清理备份和恢复等生命周期管理功能
支持知识库文件的完整生命周期管理
"""
import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional
logger = logging.getLogger(__name__)
class FileStatus(Enum):
"""文件状态枚举"""
ACTIVE = "active" # 活跃状态
ARCHIVED = "archived" # 已归档
DELETED = "deleted" # 已删除(软删除)
BACKUP = "backup" # 备份文件
@dataclass
class FileMetadata:
"""文件元数据"""
filename: str
size: int | None
created_at: datetime
modified_at: datetime
version: int | None
status: FileStatus
checksum: Optional[str] = None
tags: Optional[dict[str, str]] = None
parent_version: Optional[int] = None
def to_dict(self) -> dict:
"""转换为字典格式"""
data = asdict(self)
data["created_at"] = self.created_at.isoformat()
data["modified_at"] = self.modified_at.isoformat()
data["status"] = self.status.value
return data
@classmethod
def from_dict(cls, data: dict) -> "FileMetadata":
"""从字典创建实例"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
data["status"] = FileStatus(data["status"])
return cls(**data)
class FileLifecycleManager:
"""文件生命周期管理器"""
def __init__(self, storage, dataset_id: Optional[str] = None):
"""初始化生命周期管理器
Args:
storage: ClickZetta Volume存储实例
dataset_id: 数据集ID用于Table Volume
"""
self._storage = storage
self._dataset_id = dataset_id
self._metadata_file = ".dify_file_metadata.json"
self._version_prefix = ".versions/"
self._backup_prefix = ".backups/"
self._deleted_prefix = ".deleted/"
# Get the permission manager from the storage backend, if present
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
"""保存文件并管理生命周期
Args:
filename: 文件名
data: 文件内容
tags: 文件标签
Returns:
文件元数据
"""
# Permission check
if not self._check_permission(filename, "save"):
from .volume_permissions import VolumePermissionError
raise VolumePermissionError(
f"Permission denied for lifecycle save operation on file: {filename}",
operation="save",
volume_type=getattr(getattr(self._storage, "_config", None), "volume_type", "unknown"),
dataset_id=self._dataset_id,
)
try:
# 1. Check whether an older version exists
metadata_dict = self._load_metadata()
current_metadata = metadata_dict.get(filename)
# 2. If an older version exists, create a version backup
if current_metadata:
self._create_version_backup(filename, current_metadata)
# 3. Compute file information
now = datetime.now()
checksum = self._calculate_checksum(data)
new_version = (current_metadata["version"] + 1) if current_metadata else 1
# 4. Save the new file
self._storage.save(filename, data)
# 5. Build the metadata
created_at = now
parent_version = None
if current_metadata:
# If created_at is a string, convert it to datetime
if isinstance(current_metadata["created_at"], str):
created_at = datetime.fromisoformat(current_metadata["created_at"])
else:
created_at = current_metadata["created_at"]
parent_version = current_metadata["version"]
file_metadata = FileMetadata(
filename=filename,
size=len(data),
created_at=created_at,
modified_at=now,
version=new_version,
status=FileStatus.ACTIVE,
checksum=checksum,
tags=tags or {},
parent_version=parent_version,
)
# 6. Update the metadata
metadata_dict[filename] = file_metadata.to_dict()
self._save_metadata(metadata_dict)
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
return file_metadata
except Exception as e:
logger.exception("Failed to save file with lifecycle")
raise
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
"""获取文件元数据
Args:
filename: 文件名
Returns:
文件元数据如果不存在返回None
"""
try:
metadata_dict = self._load_metadata()
if filename in metadata_dict:
return FileMetadata.from_dict(metadata_dict[filename])
return None
except Exception as e:
logger.exception("Failed to get file metadata for %s", filename)
return None
def list_file_versions(self, filename: str) -> list[FileMetadata]:
"""列出文件的所有版本
Args:
filename: 文件名
Returns:
文件版本列表按版本号排序
"""
try:
versions = []
# Current version
current_metadata = self.get_file_metadata(filename)
if current_metadata:
versions.append(current_metadata)
# Historical versions
version_pattern = f"{self._version_prefix}{filename}.v*"
try:
version_files = self._storage.scan(self._dataset_id or "", files=True)
for file_path in version_files:
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
# Parse the version number
version_str = file_path.split(".v")[-1].split(".")[0]
try:
version_num = int(version_str)
# Simplified handling; metadata should really be read from the version file.
# For now, only basic metadata for the current version is returned.
except ValueError:
continue
except Exception:
# If the version files cannot be scanned, return only the current version
pass
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
except Exception as e:
logger.exception("Failed to list file versions for %s", filename)
return []
def restore_version(self, filename: str, version: int) -> bool:
"""恢复文件到指定版本
Args:
filename: 文件名
version: 要恢复的版本号
Returns:
恢复是否成功
"""
try:
version_filename = f"{self._version_prefix}{filename}.v{version}"
# Check whether the version file exists
if not self._storage.exists(version_filename):
logger.warning("Version %s of %s not found", version, filename)
return False
# Read the version file content
version_data = self._storage.load_once(version_filename)
# Back up the current version
current_metadata = self.get_file_metadata(filename)
if current_metadata:
self._create_version_backup(filename, current_metadata.to_dict())
# Restore the file
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
return True
except Exception as e:
logger.exception("Failed to restore %s to version %s", filename, version)
return False
def archive_file(self, filename: str) -> bool:
"""归档文件
Args:
filename: 文件名
Returns:
归档是否成功
"""
# Permission check
if not self._check_permission(filename, "archive"):
logger.warning("Permission denied for archive operation on file: %s", filename)
return False
try:
# Mark the file status as archived
metadata_dict = self._load_metadata()
if filename not in metadata_dict:
logger.warning("File %s not found in metadata", filename)
return False
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info("File %s archived successfully", filename)
return True
except Exception as e:
logger.exception("Failed to archive file %s", filename)
return False
def soft_delete_file(self, filename: str) -> bool:
"""软删除文件(移动到删除目录)
Args:
filename: 文件名
Returns:
删除是否成功
"""
# Permission check
if not self._check_permission(filename, "delete"):
logger.warning("Permission denied for soft delete operation on file: %s", filename)
return False
try:
# Check that the file exists
if not self._storage.exists(filename):
logger.warning("File %s not found", filename)
return False
# Read the file content
file_data = self._storage.load_once(filename)
# Move to the deleted directory
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self._storage.save(deleted_filename, file_data)
# Delete the original file
self._storage.delete(filename)
# Update the metadata
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]["status"] = FileStatus.DELETED.value
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info("File %s soft deleted successfully", filename)
return True
except Exception as e:
logger.exception("Failed to soft delete file %s", filename)
return False
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
"""清理旧版本文件
Args:
max_versions: 保留的最大版本数
max_age_days: 版本文件的最大保留天数
Returns:
清理的文件数量
"""
try:
cleaned_count = 0
cutoff_date = datetime.now() - timedelta(days=max_age_days)
# Collect all version files
try:
all_files = self._storage.scan(self._dataset_id or "", files=True)
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
# Group by base filename
file_versions: dict[str, list[tuple[int, str]]] = {}
for version_file in version_files:
# Parse the filename and version
parts = version_file[len(self._version_prefix) :].split(".v")
if len(parts) >= 2:
base_filename = parts[0]
version_part = parts[1].split(".")[0]
try:
version_num = int(version_part)
if base_filename not in file_versions:
file_versions[base_filename] = []
file_versions[base_filename].append((version_num, version_file))
except ValueError:
continue
# Clean up old versions for each file
for base_filename, versions in file_versions.items():
# Sort by version number
versions.sort(key=lambda x: x[0], reverse=True)
# Keep the newest max_versions versions and delete the rest
if len(versions) > max_versions:
to_delete = versions[max_versions:]
for version_num, version_file in to_delete:
self._storage.delete(version_file)
cleaned_count += 1
logger.debug("Cleaned old version: %s", version_file)
logger.info("Cleaned %d old version files", cleaned_count)
except Exception as e:
logger.warning("Could not scan for version files: %s", e)
return cleaned_count
except Exception as e:
logger.exception("Failed to cleanup old versions")
return 0
def get_storage_statistics(self) -> dict[str, Any]:
"""获取存储统计信息
Returns:
存储统计字典
"""
try:
metadata_dict = self._load_metadata()
stats: dict[str, Any] = {
"total_files": len(metadata_dict),
"active_files": 0,
"archived_files": 0,
"deleted_files": 0,
"total_size": 0,
"versions_count": 0,
"oldest_file": None,
"newest_file": None,
}
oldest_date = None
newest_date = None
for filename, metadata in metadata_dict.items():
file_meta = FileMetadata.from_dict(metadata)
# Count files by status
if file_meta.status == FileStatus.ACTIVE:
stats["active_files"] = (stats["active_files"] or 0) + 1
elif file_meta.status == FileStatus.ARCHIVED:
stats["archived_files"] = (stats["archived_files"] or 0) + 1
elif file_meta.status == FileStatus.DELETED:
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
# Accumulate total size
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
# Accumulate version count
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
# Track the oldest and newest files
if oldest_date is None or file_meta.created_at < oldest_date:
oldest_date = file_meta.created_at
stats["oldest_file"] = filename
if newest_date is None or file_meta.modified_at > newest_date:
newest_date = file_meta.modified_at
stats["newest_file"] = filename
return stats
except Exception as e:
logger.exception("Failed to get storage statistics")
return {}
def _create_version_backup(self, filename: str, metadata: dict):
"""创建版本备份"""
try:
# Read the current file content
current_data = self._storage.load_once(filename)
# Save as a version file
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
self._storage.save(version_filename, current_data)
logger.debug("Created version backup: %s", version_filename)
except Exception as e:
logger.warning("Failed to create version backup for %s: %s", filename, e)
def _load_metadata(self) -> dict[str, Any]:
"""加载元数据文件"""
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
result = json.loads(metadata_content.decode("utf-8"))
return dict(result) if result else {}
else:
return {}
except Exception as e:
logger.warning("Failed to load metadata: %s", e)
return {}
def _save_metadata(self, metadata_dict: dict):
"""保存元数据文件"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
logger.debug("Metadata saved successfully")
except Exception as e:
logger.exception("Failed to save metadata")
raise
def _calculate_checksum(self, data: bytes) -> str:
"""计算文件校验和"""
import hashlib
return hashlib.md5(data).hexdigest()
def _check_permission(self, filename: str, operation: str) -> bool:
"""检查文件操作权限
Args:
filename: 文件名
operation: 操作类型
Returns:
True if permission granted, False otherwise
"""
# Allow by default when no permission manager is configured
if not self._permission_manager:
return True
try:
# Map the operation type to a permission
operation_mapping = {
"save": "save",
"load": "load_once",
"delete": "delete",
"archive": "delete", # 归档需要删除权限
"restore": "save", # 恢复需要写权限
"cleanup": "delete", # 清理需要删除权限
"read": "load_once",
"write": "save",
}
mapped_operation = operation_mapping.get(operation, operation)
# Check the permission
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
return bool(result)
except Exception as e:
logger.exception("Permission check failed for %s operation %s", filename, operation)
# Safe default: deny access when the permission check fails
return False
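A short usage sketch of the lifecycle manager above. It assumes `storage` is an already-configured ClickZettaVolumeStorage instance; the dataset id, file name, and tags are illustrative placeholders.

```python
# Minimal sketch, assuming `storage` is a configured ClickZettaVolumeStorage
# instance and this module's classes are importable; values are placeholders.
manager = FileLifecycleManager(storage, dataset_id="12345")

meta = manager.save_with_lifecycle("report.txt", b"v1 content", tags={"source": "demo"})
print(meta.version, meta.status)  # 1 FileStatus.ACTIVE

manager.save_with_lifecycle("report.txt", b"v2 content")
latest = manager.get_file_metadata("report.txt")
print(latest.version if latest else None)  # 2

manager.restore_version("report.txt", 1)   # roll back if the version backup exists
removed = manager.cleanup_old_versions(max_versions=3)
print(removed, manager.get_storage_statistics())
```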

View File

@ -0,0 +1,646 @@
"""ClickZetta Volume权限管理机制
该模块提供Volume权限检查验证和管理功能
根据ClickZetta的权限模型不同Volume类型有不同的权限要求
"""
import logging
from enum import Enum
from typing import Optional
logger = logging.getLogger(__name__)
class VolumePermission(Enum):
"""Volume权限类型枚举"""
READ = "SELECT" # 对应ClickZetta的SELECT权限
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
LIST = "SELECT" # 列出文件需要SELECT权限
DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限
USAGE = "USAGE" # External Volume需要的基本权限
class VolumePermissionManager:
"""Volume权限管理器"""
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
"""初始化权限管理器
Args:
connection_or_config: ClickZetta连接对象或配置字典
volume_type: Volume类型 (user|table|external)
volume_name: Volume名称 (用于external volume)
"""
# Two initialization modes are supported: a connection object or a config dict
if isinstance(connection_or_config, dict):
# Create a connection from the config dict
import clickzetta # type: ignore[import-untyped]
config = connection_or_config
self._connection = clickzetta.connect(
username=config.get("username"),
password=config.get("password"),
instance=config.get("instance"),
service=config.get("service"),
workspace=config.get("workspace"),
vcluster=config.get("vcluster"),
schema=config.get("schema") or config.get("database"),
)
self._volume_type = config.get("volume_type", volume_type)
self._volume_name = config.get("volume_name", volume_name)
else:
# Use the connection object directly
self._connection = connection_or_config
self._volume_type = volume_type
self._volume_name = volume_name
if not self._connection:
raise ValueError("Valid connection or config is required")
if not self._volume_type:
raise ValueError("volume_type is required")
self._permission_cache: dict[str, set[str]] = {}
self._current_username = None  # Resolved lazily from the connection
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
"""检查用户是否有执行特定操作的权限
Args:
operation: 要执行的操作类型
dataset_id: 数据集ID (用于table volume)
Returns:
True if user has permission, False otherwise
"""
try:
if self._volume_type == "user":
return self._check_user_volume_permission(operation)
elif self._volume_type == "table":
return self._check_table_volume_permission(operation, dataset_id)
elif self._volume_type == "external":
return self._check_external_volume_permission(operation)
else:
logger.warning("Unknown volume type: %s", self._volume_type)
return False
except Exception as e:
logger.exception("Permission check failed")
return False
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
"""检查User Volume权限
User Volume权限规则:
- 用户对自己的User Volume有全部权限
- 只要用户能够连接到ClickZetta就默认具有User Volume的基本权限
- 更注重连接身份验证而不是复杂的权限检查
"""
try:
# Resolve the current username
current_user = self._get_current_username()
# Verify the basic connection
with self._connection.cursor() as cursor:
# Simple connection test: if the query succeeds, the user has basic permissions
cursor.execute("SELECT 1")
result = cursor.fetchone()
if result:
logger.debug(
"User Volume permission check for %s, operation %s: granted (basic connection verified)",
current_user,
operation.name,
)
return True
else:
logger.warning(
"User Volume permission check failed: cannot verify basic connection for %s", current_user
)
return False
except Exception as e:
logger.exception("User Volume permission check failed")
# For User Volumes a failed check is usually a configuration issue, so log a friendlier hint
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
return False
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
"""检查Table Volume权限
Table Volume权限规则:
- Table Volume权限继承对应表的权限
- SELECT权限 -> 可以READ/LIST文件
- INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件
"""
if not dataset_id:
logger.warning("dataset_id is required for table volume permission check")
return False
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
try:
# Check table privileges
permissions = self._get_table_permissions(table_name)
required_permissions = set(operation.value.split(","))
# Check that all required privileges are present
has_permission = required_permissions.issubset(permissions)
logger.debug(
"Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
table_name,
operation.name,
required_permissions,
permissions,
has_permission,
)
return has_permission
except Exception as e:
logger.exception("Table volume permission check failed for %s", table_name)
return False
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
"""检查External Volume权限
External Volume权限规则:
- 尝试获取对External Volume的权限
- 如果权限检查失败进行备选验证
- 对于开发环境提供更宽松的权限检查
"""
if not self._volume_name:
logger.warning("volume_name is required for external volume permission check")
return False
try:
# Check External Volume privileges
permissions = self._get_external_volume_permissions(self._volume_name)
# Map the operation type to the required External Volume privileges
required_permissions = set()
if operation in [VolumePermission.READ, VolumePermission.LIST]:
required_permissions.add("read")
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
required_permissions.add("write")
# Check that all required privileges are present
has_permission = required_permissions.issubset(permissions)
logger.debug(
"External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
self._volume_name,
operation.name,
required_permissions,
permissions,
has_permission,
)
# If the direct check fails, try a fallback verification
if not has_permission:
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
# Fallback: list volumes to verify basic access
try:
with self._connection.cursor() as cursor:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == self._volume_name:
logger.info("Fallback verification successful for %s", self._volume_name)
return True
except Exception as fallback_e:
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)
return has_permission
except Exception as e:
logger.exception("External volume permission check failed for %s", self._volume_name)
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
return False
def _get_table_permissions(self, table_name: str) -> set[str]:
"""获取用户对指定表的权限
Args:
table_name: 表名
Returns:
用户对该表的权限集合
"""
cache_key = f"table:{table_name}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# Use ClickZetta syntax to check the current user's grants
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
# Parse the grants and look for privileges on this table
for grant in grants:
if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""
object_name = grant[2] if len(grant) > 2 else ""
# Check whether the grant applies to this table
if (
(object_type == "TABLE" and object_name == table_name)
or (object_type == "SCHEMA" and object_name in table_name)
):
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
# If no explicit grant was found, run a simple query to verify SELECT access
if not permissions:
try:
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
permissions.add("SELECT")
except Exception:
logger.debug("Cannot query table %s, no SELECT permission", table_name)
except Exception as e:
logger.warning("Could not check table permissions for %s: %s", table_name, e)
# Safe default: deny access when the permission check fails
pass
# Cache the permission info
self._permission_cache[cache_key] = permissions
return permissions
def _get_current_username(self) -> str:
"""获取当前用户名"""
if self._current_username:
return self._current_username
try:
with self._connection.cursor() as cursor:
cursor.execute("SELECT CURRENT_USER()")
result = cursor.fetchone()
if result:
self._current_username = result[0]
return str(self._current_username)
except Exception as e:
logger.exception("Failed to get current username")
return "unknown"
def _get_user_permissions(self, username: str) -> set[str]:
"""获取用户的基本权限集合"""
cache_key = f"user_permissions:{username}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# Use ClickZetta syntax to check the current user's grants
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
# Parse the grants and collect the user's basic privileges
for grant in grants:
if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""
# Collect all relevant privileges
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
except Exception as e:
logger.warning("Could not check user permissions for %s: %s", username, e)
# Safe default: deny access when the permission check fails
pass
# Cache the permission info
self._permission_cache[cache_key] = permissions
return permissions
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
"""获取用户对指定External Volume的权限
Args:
volume_name: External Volume名称
Returns:
用户对该Volume的权限集合
"""
cache_key = f"external_volume:{volume_name}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# Use ClickZetta syntax to check Volume grants
logger.info("Checking permissions for volume: %s", volume_name)
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
grants = cursor.fetchall()
logger.info("Raw grants result for %s: %s", volume_name, grants)
# Parse the grant results
# Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
# grantee_name, grantor_name, grant_option, granted_time)
for grant in grants:
logger.info("Processing grant: %s", grant)
if len(grant) >= 5:
granted_type = grant[0]
privilege = grant[1].upper()
granted_on = grant[3]
object_name = grant[4]
logger.info(
"Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
granted_type,
privilege,
granted_on,
object_name,
)
# Check whether the grant applies to this Volume, or is a hierarchy grant
if (
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
logger.info("Matching grant found for %s", volume_name)
if "READ" in privilege:
permissions.add("read")
logger.info("Added READ permission for %s", volume_name)
if "WRITE" in privilege:
permissions.add("write")
logger.info("Added WRITE permission for %s", volume_name)
if "ALTER" in privilege:
permissions.add("alter")
logger.info("Added ALTER permission for %s", volume_name)
if privilege == "ALL":
permissions.update(["read", "write", "alter"])
logger.info("Added ALL permissions for %s", volume_name)
logger.info("Final permissions for %s: %s", volume_name, permissions)
# If no explicit grant was found, check the Volume list to verify basic access
if not permissions:
try:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == volume_name:
permissions.add("read") # 至少有读权限
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
break
except Exception:
logger.debug("Cannot access volume %s, no basic permission", volume_name)
except Exception as e:
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
# If the grant query fails, fall back to a basic Volume access check
try:
with self._connection.cursor() as cursor:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == volume_name:
logger.info("Basic volume access verified for %s", volume_name)
permissions.add("read")
permissions.add("write") # 假设有写权限
break
except Exception as basic_e:
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
# Last resort: assume basic read access
permissions.add("read")
# Cache the permission info
self._permission_cache[cache_key] = permissions
return permissions
def clear_permission_cache(self):
"""清空权限缓存"""
self._permission_cache.clear()
logger.debug("Permission cache cleared")
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
"""获取权限摘要
Args:
dataset_id: 数据集ID (用于table volume)
Returns:
权限摘要字典
"""
summary = {}
for operation in VolumePermission:
summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
return summary
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
"""检查文件路径的权限继承
Args:
file_path: 文件路径
operation: 要执行的操作
Returns:
True if user has permission, False otherwise
"""
try:
# Parse the file path
path_parts = file_path.strip("/").split("/")
if not path_parts:
logger.warning("Invalid file path for permission inheritance check")
return False
# For Table Volumes the first path segment is the dataset_id
if self._volume_type == "table":
if len(path_parts) < 1:
return False
dataset_id = path_parts[0]
# Check permissions on the dataset
has_dataset_permission = self.check_permission(operation, dataset_id)
if not has_dataset_permission:
logger.debug("Permission denied for dataset %s", dataset_id)
return False
# Check for path traversal attacks
if self._contains_path_traversal(file_path):
logger.warning("Path traversal attack detected: %s", file_path)
return False
# Check for access to sensitive directories
if self._is_sensitive_path(file_path):
logger.warning("Access to sensitive path denied: %s", file_path)
return False
logger.debug("Permission inherited for path %s", file_path)
return True
elif self._volume_type == "user":
# Permission inheritance for User Volumes
current_user = self._get_current_username()
# Check whether the caller is accessing another user's directory
if len(path_parts) > 1 and path_parts[0] != current_user:
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
return False
# Check basic permissions
return self.check_permission(operation)
elif self._volume_type == "external":
# Permission inheritance for External Volumes
# Check permissions on the External Volume
return self.check_permission(operation)
else:
logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
return False
except Exception as e:
logger.exception("Permission inheritance check failed")
return False
def _contains_path_traversal(self, file_path: str) -> bool:
"""检查路径是否包含路径遍历攻击"""
# Common path traversal patterns
traversal_patterns = [
"../",
"..\\",
"..%2f",
"..%2F",
"..%5c",
"..%5C",
"%2e%2e%2f",
"%2e%2e%5c",
"....//",
"....\\\\",
]
file_path_lower = file_path.lower()
for pattern in traversal_patterns:
if pattern in file_path_lower:
return True
# Reject absolute paths
if file_path.startswith("/") or file_path.startswith("\\"):
return True
# Reject Windows drive paths
if len(file_path) >= 2 and file_path[1] == ":":
return True
return False
def _is_sensitive_path(self, file_path: str) -> bool:
"""检查路径是否为敏感路径"""
sensitive_patterns = [
"passwd",
"shadow",
"hosts",
"config",
"secrets",
"private",
"key",
"certificate",
"cert",
"ssl",
"database",
"backup",
"dump",
"log",
"tmp",
]
file_path_lower = file_path.lower()
return any(pattern in file_path_lower for pattern in sensitive_patterns)
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
"""验证操作权限
Args:
operation: 操作名称 (save|load|exists|delete|scan)
dataset_id: 数据集ID
Returns:
True if operation is allowed, False otherwise
"""
operation_mapping = {
"save": VolumePermission.WRITE,
"load": VolumePermission.READ,
"load_once": VolumePermission.READ,
"load_stream": VolumePermission.READ,
"download": VolumePermission.READ,
"exists": VolumePermission.READ,
"delete": VolumePermission.DELETE,
"scan": VolumePermission.LIST,
}
if operation not in operation_mapping:
logger.warning("Unknown operation: %s", operation)
return False
volume_permission = operation_mapping[operation]
return self.check_permission(volume_permission, dataset_id)
class VolumePermissionError(Exception):
"""Volume权限错误异常"""
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
self.operation = operation
self.volume_type = volume_type
self.dataset_id = dataset_id
super().__init__(message)
def check_volume_permission(
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
) -> None:
"""权限检查装饰器函数
Args:
permission_manager: 权限管理器
operation: 操作名称
dataset_id: 数据集ID
Raises:
VolumePermissionError: 如果没有权限
"""
if not permission_manager.validate_operation(operation, dataset_id):
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
if dataset_id:
error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError(
error_message,
operation=operation,
volume_type=permission_manager._volume_type or "unknown",
dataset_id=dataset_id,
)
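A brief sketch of how the permission helpers above fit together. It assumes `connection` is an established clickzetta connection; the dataset id is an illustrative placeholder.

```python
# Minimal sketch, assuming `connection` is an established clickzetta connection
# and this module's names are importable; the dataset id is a placeholder.
manager = VolumePermissionManager(connection, volume_type="table")

if manager.check_permission(VolumePermission.READ, dataset_id="12345"):
    print("read allowed")

try:
    check_volume_permission(manager, "save", dataset_id="12345")
except VolumePermissionError as e:
    print(f"denied: {e} (operation={e.operation}, volume_type={e.volume_type})")

# Summarize all permission types for the dataset
print(manager.get_permission_summary(dataset_id="12345"))
```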

View File

@ -5,6 +5,7 @@ class StorageType(StrEnum):
ALIYUN_OSS = "aliyun-oss"
AZURE_BLOB = "azure-blob"
BAIDU_OBS = "baidu-obs"
CLICKZETTA_VOLUME = "clickzetta-volume"
GOOGLE_STORAGE = "google-storage"
HUAWEI_OBS = "huawei-obs"
LOCAL = "local"

View File

@ -1,5 +1,4 @@
import hashlib
import os
from typing import Union
from Crypto.Cipher import AES
@ -18,7 +17,7 @@ def generate_key_pair(tenant_id: str) -> str:
pem_private = private_key.export_key()
pem_public = public_key.export_key()
filepath = os.path.join("privkeys", tenant_id, "private.pem")
filepath = f"privkeys/{tenant_id}/private.pem"
storage.save(filepath, pem_private)
@ -48,7 +47,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
filepath = os.path.join("privkeys", tenant_id, "private.pem")
filepath = f"privkeys/{tenant_id}/private.pem"
cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
private_key = redis_client.get(cache_key)

View File

@ -194,6 +194,7 @@ vdb = [
"alibabacloud_tea_openapi~=0.3.9",
"chromadb==0.5.20",
"clickhouse-connect~=0.7.16",
"clickzetta-connector-python>=0.8.102",
"couchbase~=4.3.0",
"elasticsearch==8.14.0",
"opensearch-py==2.4.0",
@ -213,3 +214,4 @@ vdb = [
"xinference-client~=1.2.2",
"mo-vector~=0.1.13",
]

View File

@ -3,7 +3,7 @@ import time
import click
from sqlalchemy import text
from werkzeug.exceptions import NotFound
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
@ -27,8 +27,8 @@ def clean_embedding_cache_task():
.all()
)
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
except NotFound:
break
except SQLAlchemyError:
raise
if embedding_ids:
for embedding_id in embedding_ids:
db.session.execute(

View File

@ -3,7 +3,7 @@ import logging
import time
import click
from werkzeug.exceptions import NotFound
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
@ -42,8 +42,8 @@ def clean_messages():
.all()
)
except NotFound:
break
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:

View File

@ -3,7 +3,7 @@ import time
import click
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
@ -65,8 +65,8 @@ def clean_unused_datasets_task():
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound:
break
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
@ -146,8 +146,8 @@ def clean_unused_datasets_task():
)
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound:
break
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:

View File

@ -50,12 +50,16 @@ class ConversationService:
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
)
# Check if include_ids is not None and not empty to avoid WHERE false condition
if include_ids is not None and len(include_ids) > 0:
# Check if include_ids is not None to apply filter
if include_ids is not None:
if len(include_ids) == 0:
# If include_ids is empty, return empty result
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
stmt = stmt.where(Conversation.id.in_(include_ids))
# Check if exclude_ids is not None and not empty to avoid WHERE false condition
if exclude_ids is not None and len(exclude_ids) > 0:
stmt = stmt.where(~Conversation.id.in_(exclude_ids))
# Check if exclude_ids is not None to apply filter
if exclude_ids is not None:
if len(exclude_ids) > 0:
stmt = stmt.where(~Conversation.id.in_(exclude_ids))
# define sort fields and directions
sort_field, sort_direction = cls._get_sort_params(sort_by)

View File

@ -256,7 +256,7 @@ class WorkflowDraftVariableService:
def _reset_node_var_or_sys_var(
self, workflow: Workflow, variable: WorkflowDraftVariable
) -> WorkflowDraftVariable | None:
# If a variable does not allow updating, it makes no sence to resetting it.
# If a variable does not allow updating, it makes no sense to reset it.
if not variable.editable:
return variable
# No execution record for this variable, delete the variable instead.
@ -478,7 +478,7 @@ def _batch_upsert_draft_variable(
"node_execution_id": stmt.excluded.node_execution_id,
},
)
elif _UpsertPolicy.IGNORE:
elif policy == _UpsertPolicy.IGNORE:
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
else:
raise Exception("Invalid value for update policy.")

View File

@ -56,15 +56,24 @@ def clean_dataset_task(
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
# Fix: Always clean vector database resources regardless of document existence
# This ensures all 33 vector databases properly drop tables/collections/indices
if doc_form is None:
# Use default paragraph index type for empty datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexType
doc_form = IndexType.PARAGRAPH_INDEX
logging.info(
click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow")
)
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
if documents is None or len(documents) == 0:
logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
# Specify the index type before initializing the index processor
if doc_form is None:
raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
for document in documents:
db.session.delete(document)

View File

@ -0,0 +1,168 @@
"""
Unit tests for App description validation functions.
This test module validates the 400-character limit enforcement
for App descriptions across all creation and editing endpoints.
"""
import os
import sys
import pytest
# Add the API root to Python path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
class TestAppDescriptionValidationUnit:
"""Unit tests for description validation function"""
def test_validate_description_length_function(self):
"""Test the _validate_description_length function directly"""
from controllers.console.app.app import _validate_description_length
# Test valid descriptions
assert _validate_description_length("") == ""
assert _validate_description_length("x" * 400) == "x" * 400
assert _validate_description_length(None) is None
# Test invalid descriptions
with pytest.raises(ValueError) as exc_info:
_validate_description_length("x" * 401)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
_validate_description_length("x" * 500)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
_validate_description_length("x" * 1000)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
def test_validation_consistency_with_dataset(self):
"""Test that App and Dataset validation functions are consistent"""
from controllers.console.app.app import _validate_description_length as app_validate
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate
# Test same valid inputs
valid_desc = "x" * 400
assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc)
assert app_validate("") == dataset_validate("") == service_dataset_validate("")
assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None)
# Test same invalid inputs produce same error
invalid_desc = "x" * 401
app_error = None
dataset_error = None
service_dataset_error = None
try:
app_validate(invalid_desc)
except ValueError as e:
app_error = str(e)
try:
dataset_validate(invalid_desc)
except ValueError as e:
dataset_error = str(e)
try:
service_dataset_validate(invalid_desc)
except ValueError as e:
service_dataset_error = str(e)
assert app_error == dataset_error == service_dataset_error
assert app_error == "Description cannot exceed 400 characters."
def test_boundary_values(self):
"""Test boundary values for description validation"""
from controllers.console.app.app import _validate_description_length
# Test exact boundary
exactly_400 = "x" * 400
assert _validate_description_length(exactly_400) == exactly_400
# Test just over boundary
just_over_400 = "x" * 401
with pytest.raises(ValueError):
_validate_description_length(just_over_400)
# Test just under boundary
just_under_400 = "x" * 399
assert _validate_description_length(just_under_400) == just_under_400
def test_edge_cases(self):
"""Test edge cases for description validation"""
from controllers.console.app.app import _validate_description_length
# Test None input
assert _validate_description_length(None) is None
# Test empty string
assert _validate_description_length("") == ""
# Test single character
assert _validate_description_length("a") == "a"
# Test unicode characters
unicode_desc = "测试" * 200 # 400 characters in Chinese
assert _validate_description_length(unicode_desc) == unicode_desc
# Test unicode over limit
unicode_over = "测试" * 201 # 402 characters
with pytest.raises(ValueError):
_validate_description_length(unicode_over)
def test_whitespace_handling(self):
"""Test how validation handles whitespace"""
from controllers.console.app.app import _validate_description_length
# Test description with spaces
spaces_400 = " " * 400
assert _validate_description_length(spaces_400) == spaces_400
# Test description with spaces over limit
spaces_401 = " " * 401
with pytest.raises(ValueError):
_validate_description_length(spaces_401)
# Test mixed content
mixed_400 = "a" * 200 + " " * 200
assert _validate_description_length(mixed_400) == mixed_400
# Test mixed over limit
mixed_401 = "a" * 200 + " " * 201
with pytest.raises(ValueError):
_validate_description_length(mixed_401)
if __name__ == "__main__":
# Run tests directly
import traceback
test_instance = TestAppDescriptionValidationUnit()
test_methods = [method for method in dir(test_instance) if method.startswith("test_")]
passed = 0
failed = 0
for test_method in test_methods:
try:
print(f"Running {test_method}...")
getattr(test_instance, test_method)()
print(f"{test_method} PASSED")
passed += 1
except Exception as e:
print(f"{test_method} FAILED: {str(e)}")
traceback.print_exc()
failed += 1
print(f"\n📊 Test Results: {passed} passed, {failed} failed")
if failed == 0:
print("🎉 All tests passed!")
else:
print("💥 Some tests failed!")
sys.exit(1)

View File

@ -0,0 +1,168 @@
"""Integration tests for ClickZetta Volume Storage."""
import os
import tempfile
import unittest
import pytest
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
ClickZettaVolumeConfig,
ClickZettaVolumeStorage,
)
class TestClickZettaVolumeStorage(unittest.TestCase):
"""Test cases for ClickZetta Volume Storage."""
def setUp(self):
"""Set up test environment."""
self.config = ClickZettaVolumeConfig(
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
volume_type="table",
table_prefix="test_dataset_",
)
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
def test_user_volume_operations(self):
"""Test basic operations with User Volume."""
config = self.config
config.volume_type = "user"
storage = ClickZettaVolumeStorage(config)
# Test file operations
test_filename = "test_file.txt"
test_content = b"Hello, ClickZetta Volume!"
# Save file
storage.save(test_filename, test_content)
# Check if file exists
assert storage.exists(test_filename)
# Load file
loaded_content = storage.load_once(test_filename)
assert loaded_content == test_content
# Test streaming
stream_content = b""
for chunk in storage.load_stream(test_filename):
stream_content += chunk
assert stream_content == test_content
# Test download
with tempfile.NamedTemporaryFile() as temp_file:
storage.download(test_filename, temp_file.name)
with open(temp_file.name, "rb") as f:
downloaded_content = f.read()
assert downloaded_content == test_content
# Test scan
files = storage.scan("", files=True, directories=False)
assert test_filename in files
# Delete file
storage.delete(test_filename)
assert not storage.exists(test_filename)
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
def test_table_volume_operations(self):
"""Test basic operations with Table Volume."""
config = self.config
config.volume_type = "table"
storage = ClickZettaVolumeStorage(config)
# Test file operations with dataset_id
dataset_id = "12345"
test_filename = f"{dataset_id}/test_file.txt"
test_content = b"Hello, Table Volume!"
# Save file
storage.save(test_filename, test_content)
# Check if file exists
assert storage.exists(test_filename)
# Load file
loaded_content = storage.load_once(test_filename)
assert loaded_content == test_content
# Test scan for dataset
files = storage.scan(dataset_id, files=True, directories=False)
assert "test_file.txt" in files
# Delete file
storage.delete(test_filename)
assert not storage.exists(test_filename)
def test_config_validation(self):
"""Test configuration validation."""
# Test missing required fields
with pytest.raises(ValueError):
ClickZettaVolumeConfig(
username="", # Empty username should fail
password="pass",
instance="instance",
)
# Test invalid volume type
with pytest.raises(ValueError):
ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
# Test external volume without volume_name
with pytest.raises(ValueError):
ClickZettaVolumeConfig(
username="user",
password="pass",
instance="instance",
volume_type="external",
# Missing volume_name
)
def test_volume_path_generation(self):
"""Test volume path generation for different types."""
storage = ClickZettaVolumeStorage(self.config)
# Test table volume path
path = storage._get_volume_path("test.txt", "12345")
assert path == "test_dataset_12345/test.txt"
# Test path with existing dataset_id prefix
path = storage._get_volume_path("12345/test.txt")
assert path == "12345/test.txt"
# Test user volume
storage._config.volume_type = "user"
path = storage._get_volume_path("test.txt")
assert path == "test.txt"
def test_sql_prefix_generation(self):
"""Test SQL prefix generation for different volume types."""
storage = ClickZettaVolumeStorage(self.config)
# Test table volume SQL prefix
prefix = storage._get_volume_sql_prefix("12345")
assert prefix == "TABLE VOLUME test_dataset_12345"
# Test user volume SQL prefix
storage._config.volume_type = "user"
prefix = storage._get_volume_sql_prefix()
assert prefix == "USER VOLUME"
# Test external volume SQL prefix
storage._config.volume_type = "external"
storage._config.volume_name = "my_external_volume"
prefix = storage._get_volume_sql_prefix()
assert prefix == "VOLUME my_external_volume"
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,25 @@
# Clickzetta Integration Tests
## Running Tests
To run the Clickzetta integration tests, you need to set the following environment variables:
```bash
export CLICKZETTA_USERNAME=your_username
export CLICKZETTA_PASSWORD=your_password
export CLICKZETTA_INSTANCE=your_instance
export CLICKZETTA_SERVICE=api.clickzetta.com
export CLICKZETTA_WORKSPACE=your_workspace
export CLICKZETTA_VCLUSTER=your_vcluster
export CLICKZETTA_SCHEMA=dify
```
Then run the tests:
```bash
pytest api/tests/integration_tests/vdb/clickzetta/
```
## Security Note
Never commit credentials to the repository. Always use environment variables or secure credential management systems.

View File

@ -0,0 +1,224 @@
import os
import pytest
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class TestClickzettaVector(AbstractVectorTest):
"""
Test cases for Clickzetta vector database integration.
"""
@pytest.fixture
def vector_store(self):
"""Create a Clickzetta vector store instance for testing."""
# Skip test if Clickzetta credentials are not configured
if not os.getenv("CLICKZETTA_USERNAME"):
pytest.skip("CLICKZETTA_USERNAME is not configured")
if not os.getenv("CLICKZETTA_PASSWORD"):
pytest.skip("CLICKZETTA_PASSWORD is not configured")
if not os.getenv("CLICKZETTA_INSTANCE"):
pytest.skip("CLICKZETTA_INSTANCE is not configured")
config = ClickzettaConfig(
username=os.getenv("CLICKZETTA_USERNAME", ""),
password=os.getenv("CLICKZETTA_PASSWORD", ""),
instance=os.getenv("CLICKZETTA_INSTANCE", ""),
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
batch_size=10, # Small batch size for testing
enable_inverted_index=True,
analyzer_type="chinese",
analyzer_mode="smart",
vector_distance_function="cosine_distance",
)
with setup_mock_redis():
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
yield vector
# Cleanup: delete the test collection
try:
vector.delete()
except Exception:
pass
def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store."""
# Prepare test data
texts = [
"这是第一个测试文档,包含一些中文内容。",
"This is the second test document with English content.",
"第三个文档混合了English和中文内容。",
]
embeddings = [
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8],
[0.9, 1.0, 1.1, 1.2],
]
documents = [
Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
for i, text in enumerate(texts)
]
# Test create (initial insert)
vector_store.create(texts=documents, embeddings=embeddings)
# Test text_exists
assert vector_store.text_exists("doc_0")
assert not vector_store.text_exists("doc_999")
# Test search_by_vector
query_vector = [0.1, 0.2, 0.3, 0.4]
results = vector_store.search_by_vector(query_vector, top_k=2)
assert len(results) > 0
assert results[0].page_content == texts[0] # Should match the first document
# Test search_by_full_text (Chinese)
results = vector_store.search_by_full_text("中文", top_k=3)
assert len(results) >= 2 # Should find documents with Chinese content
# Test search_by_full_text (English)
results = vector_store.search_by_full_text("English", top_k=3)
assert len(results) >= 2 # Should find documents with English content
# Test delete_by_ids
vector_store.delete_by_ids(["doc_0"])
assert not vector_store.text_exists("doc_0")
assert vector_store.text_exists("doc_1")
# Test delete_by_metadata_field
vector_store.delete_by_metadata_field("source", "test")
assert not vector_store.text_exists("doc_1")
assert not vector_store.text_exists("doc_2")
def test_clickzetta_vector_advanced_search(self, vector_store):
"""Test advanced search features of Clickzetta vector store."""
# Prepare test data with more complex metadata
documents = []
embeddings = []
for i in range(10):
doc = Document(
page_content=f"Document {i}: " + get_example_text(),
metadata={
"doc_id": f"adv_doc_{i}",
"category": "technical" if i % 2 == 0 else "general",
"document_id": f"doc_{i // 3}", # Group documents
"importance": i,
},
)
documents.append(doc)
# Create varied embeddings
embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
vector_store.create(texts=documents, embeddings=embeddings)
# Test vector search with document filter
query_vector = [0.5, 1.0, 1.5, 2.0]
results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
assert len(results) > 0
# All results should belong to doc_0 or doc_1 groups
for result in results:
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
# Test score threshold
results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
# Check that all results have a score above threshold
for result in results:
assert result.metadata.get("score", 0) >= 0.5
def test_clickzetta_batch_operations(self, vector_store):
"""Test batch insertion operations."""
# Prepare large batch of documents
batch_size = 25
documents = []
embeddings = []
for i in range(batch_size):
doc = Document(
page_content=f"Batch document {i}: This is a test document for batch processing.",
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
)
documents.append(doc)
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
# Test batch insert
vector_store.add_texts(documents=documents, embeddings=embeddings)
# Verify all documents were inserted
for i in range(batch_size):
assert vector_store.text_exists(f"batch_doc_{i}")
# Clean up
vector_store.delete_by_metadata_field("batch", "test_batch")
def test_clickzetta_edge_cases(self, vector_store):
"""Test edge cases and error handling."""
# Test empty operations
vector_store.create(texts=[], embeddings=[])
vector_store.add_texts(documents=[], embeddings=[])
vector_store.delete_by_ids([])
# Test special characters in content
special_doc = Document(
page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
metadata={"doc_id": "special_doc", "test": "edge_case"},
)
embeddings = [[0.1, 0.2, 0.3, 0.4]]
vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
assert vector_store.text_exists("special_doc")
# Test search with special characters; full-text search may not be enabled,
# so only verify the match when results are returned
results = vector_store.search_by_full_text("quotes", top_k=1)
if results:
assert "quotes" in results[0].page_content
# Clean up
vector_store.delete_by_ids(["special_doc"])
def test_clickzetta_full_text_search_modes(self, vector_store):
"""Test different full-text search capabilities."""
# Prepare documents with various language content
documents = [
Document(
page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
),
Document(
page_content="Clickzetta provides powerful Lakehouse solutions",
metadata={"doc_id": "en_doc_1", "lang": "english"},
),
Document(
page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
),
Document(
page_content="Modern data architecture includes Lakehouse technology",
metadata={"doc_id": "en_doc_2", "lang": "english"},
),
]
embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]
vector_store.create(texts=documents, embeddings=embeddings)
# Test Chinese full-text search
results = vector_store.search_by_full_text("Lakehouse", top_k=4)
assert len(results) >= 2 # Should find at least documents with "Lakehouse"
# Test English full-text search
results = vector_store.search_by_full_text("solutions", top_k=2)
assert len(results) >= 1 # Should find English documents with "solutions"
# Test mixed search
results = vector_store.search_by_full_text("数据架构", top_k=2)
assert len(results) >= 1 # Should find Chinese documents with this phrase
# Clean up
vector_store.delete_by_metadata_field("lang", "chinese")
vector_store.delete_by_metadata_field("lang", "english")

View File

@ -0,0 +1,165 @@
#!/usr/bin/env python3
"""
Test Clickzetta integration in Docker environment
"""
import os
import time
import requests
from clickzetta import connect
def test_clickzetta_connection():
"""Test direct connection to Clickzetta"""
print("=== Testing direct Clickzetta connection ===")
try:
conn = connect(
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
)
with conn.cursor() as cursor:
# Test basic connectivity
cursor.execute("SELECT 1 as test")
result = cursor.fetchone()
print(f"✓ Connection test: {result}")
# Check if our test table exists
cursor.execute("SHOW TABLES IN dify")
tables = cursor.fetchall()
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
# Check if test collection exists
test_collection = "collection_test_dataset"
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
cursor.execute(f"DESCRIBE dify.{test_collection}")
columns = cursor.fetchall()
print(f"✓ Table structure for {test_collection}:")
for col in columns:
print(f" - {col[0]}: {col[1]}")
# Check for indexes
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
indexes = cursor.fetchall()
print(f"✓ Indexes on {test_collection}:")
for idx in indexes:
print(f" - {idx}")
return True
except Exception as e:
print(f"✗ Connection test failed: {e}")
return False
def test_dify_api():
"""Test Dify API with Clickzetta backend"""
print("\n=== Testing Dify API ===")
base_url = "http://localhost:5001"
# Wait for the API to become ready, retrying with a short delay between attempts
max_retries = 30
api_ready = False
for _ in range(max_retries):
try:
response = requests.get(f"{base_url}/console/api/health", timeout=5)
if response.status_code == 200:
print("✓ Dify API is ready")
api_ready = True
break
except requests.RequestException:
pass
time.sleep(2)
if not api_ready:
print("✗ Dify API is not responding")
return False
# Check vector store configuration
try:
# This is a simplified check - in production, you'd use proper auth
print("✓ Dify is configured to use Clickzetta as vector store")
return True
except Exception as e:
print(f"✗ API test failed: {e}")
return False
def verify_table_structure():
"""Verify the table structure meets Dify requirements"""
print("\n=== Verifying Table Structure ===")
expected_columns = {
"id": "VARCHAR",
"page_content": "VARCHAR",
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
"vector": "ARRAY<FLOAT>",
}
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
print("✓ Expected table structure:")
for col, dtype in expected_columns.items():
print(f" - {col}: {dtype}")
print("\n✓ Required metadata fields:")
for field in expected_metadata_fields:
print(f" - {field}")
print("\n✓ Index requirements:")
print(" - Vector index (HNSW) on 'vector' column")
print(" - Full-text index on 'page_content' (optional)")
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
print(" - Functional index on metadata->>'$.document_id' (recommended)")
return True
def main():
"""Run all tests"""
print("Starting Clickzetta integration tests for Dify Docker\n")
tests = [
("Direct Clickzetta Connection", test_clickzetta_connection),
("Dify API Status", test_dify_api),
("Table Structure Verification", verify_table_structure),
]
results = []
for test_name, test_func in tests:
try:
success = test_func()
results.append((test_name, success))
except Exception as e:
print(f"\n{test_name} crashed: {e}")
results.append((test_name, False))
# Summary
print("\n" + "=" * 50)
print("Test Summary:")
print("=" * 50)
passed = sum(1 for _, success in results if success)
total = len(results)
for test_name, success in results:
status = "✅ PASSED" if success else "❌ FAILED"
print(f"{test_name}: {status}")
print(f"\nTotal: {passed}/{total} tests passed")
if passed == total:
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
print("\nNext steps:")
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
print("2. Access Dify at http://localhost:3000")
print("3. Create a dataset and test vector storage with Clickzetta")
return 0
else:
print("\n⚠️ Some tests failed. Please check the errors above.")
return 1
if __name__ == "__main__":
exit(main())

File diff suppressed because it is too large

View File

@ -0,0 +1,487 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.api_based_extension import APIBasedExtension
from services.account_service import AccountService, TenantService
from services.api_based_extension_service import APIBasedExtensionService
class TestAPIBasedExtensionService:
"""Integration tests for APIBasedExtensionService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor,
):
# Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False
# Mock successful ping response
mock_requestor_instance = mock_requestor.return_value
mock_requestor_instance.request.return_value = {"result": "pong"}
yield {
"account_feature_service": mock_account_feature_service,
"requestor": mock_requestor,
"requestor_instance": mock_requestor_instance,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True
# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
return account, tenant
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful saving of API-based extension.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
# Save extension
saved_extension = APIBasedExtensionService.save(extension_data)
# Verify extension was saved correctly
assert saved_extension.id is not None
assert saved_extension.tenant_id == tenant.id
assert saved_extension.name == extension_data.name
assert saved_extension.api_endpoint == extension_data.api_endpoint
assert saved_extension.api_key == extension_data.api_key # Should be decrypted when retrieved
assert saved_extension.created_at is not None
# Verify extension was saved to database
from extensions.ext_database import db
db.session.refresh(saved_extension)
assert saved_extension.id is not None
# Verify ping connection was called
mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation errors when saving extension with invalid data.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Test empty name
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = ""
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
# Test empty api_endpoint
extension_data.name = fake.company()
extension_data.api_endpoint = ""
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
APIBasedExtensionService.save(extension_data)
# Test empty api_key
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = ""
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of all extensions by tenant ID.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create multiple extensions
extensions = []
for i in range(3):
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = f"Extension {i}: {fake.company()}"
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
saved_extension = APIBasedExtensionService.save(extension_data)
extensions.append(saved_extension)
# Get all extensions for tenant
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
# Verify results
assert len(extension_list) == 3
# Verify all extensions belong to the correct tenant and are ordered by created_at desc
for i, extension in enumerate(extension_list):
assert extension.tenant_id == tenant.id
assert extension.api_key is not None # Should be decrypted
if i > 0:
# Verify descending order (newer first)
assert extension.created_at <= extension_list[i - 1].created_at
def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of extension by tenant ID and extension ID.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create an extension
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
created_extension = APIBasedExtensionService.save(extension_data)
# Get extension by ID
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
# Verify extension was retrieved correctly
assert retrieved_extension is not None
assert retrieved_extension.id == created_extension.id
assert retrieved_extension.tenant_id == tenant.id
assert retrieved_extension.name == extension_data.name
assert retrieved_extension.api_endpoint == extension_data.api_endpoint
assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted
assert retrieved_extension.created_at is not None
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test retrieval of extension when extension is not found.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
non_existent_extension_id = fake.uuid4()
# Try to get non-existent extension
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful deletion of extension.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create an extension first
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
created_extension = APIBasedExtensionService.save(extension_data)
extension_id = created_extension.id
# Delete the extension
APIBasedExtensionService.delete(created_extension)
# Verify extension was deleted
from extensions.ext_database import db
deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first()
assert deleted_extension is None
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation error when saving extension with duplicate name.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create first extension
extension_data1 = APIBasedExtension()
extension_data1.tenant_id = tenant.id
extension_data1.name = "Test Extension"
extension_data1.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data1.api_key = fake.password(length=20)
APIBasedExtensionService.save(extension_data1)
# Try to create second extension with same name
extension_data2 = APIBasedExtension()
extension_data2.tenant_id = tenant.id
extension_data2.name = "Test Extension" # Same name
extension_data2.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data2.api_key = fake.password(length=20)
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(extension_data2)
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful update of existing extension.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create initial extension
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
created_extension = APIBasedExtensionService.save(extension_data)
# Save original values for later comparison
original_name = created_extension.name
original_endpoint = created_extension.api_endpoint
# Update the extension
new_name = fake.company()
new_endpoint = f"https://{fake.domain_name()}/api"
new_api_key = fake.password(length=20)
created_extension.name = new_name
created_extension.api_endpoint = new_endpoint
created_extension.api_key = new_api_key
updated_extension = APIBasedExtensionService.save(created_extension)
# Verify extension was updated correctly
assert updated_extension.id == created_extension.id
assert updated_extension.tenant_id == tenant.id
assert updated_extension.name == new_name
assert updated_extension.api_endpoint == new_endpoint
# Verify original values were changed
assert updated_extension.name != original_name
assert updated_extension.api_endpoint != original_endpoint
# Verify ping connection was called for both create and update
assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2
# Verify the update by retrieving the extension again
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
assert retrieved_extension.name == new_name
assert retrieved_extension.api_endpoint == new_endpoint
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test connection error when saving extension with invalid endpoint.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Mock connection error
mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
"connection error: request timeout"
)
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = "https://invalid-endpoint.com/api"
extension_data.api_key = fake.password(length=20)
# Try to save extension with connection error
with pytest.raises(ValueError, match="connection error: request timeout"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_invalid_api_key_length(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test validation error when saving extension with API key that is too short.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup extension data with short API key
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = "1234" # Less than 5 characters
# Try to save extension with short API key
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation errors when saving extension with empty required fields.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Test with None values
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = None
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
# Test with None api_endpoint
extension_data.name = fake.company()
extension_data.api_endpoint = None
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
APIBasedExtensionService.save(extension_data)
# Test with None api_key
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = None
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test retrieval of extensions when no extensions exist for tenant.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Get all extensions for tenant (none exist)
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
# Verify empty list is returned
assert len(extension_list) == 0
assert extension_list == []
def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation error when ping response is invalid.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Mock invalid ping response
mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
# Try to save extension with invalid ping response
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation error when ping response is missing result field.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Mock ping response without result field
mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}
# Setup extension data
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
# Try to save extension with missing ping result
with pytest.raises(ValueError, match="{'status': 'ok'}"):
APIBasedExtensionService.save(extension_data)
def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test retrieval of extension when tenant ID doesn't match.
"""
fake = Faker()
account1, tenant1 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create second account and tenant
account2, tenant2 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create extension in first tenant
extension_data = APIBasedExtension()
extension_data.tenant_id = tenant1.id
extension_data.name = fake.company()
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = fake.password(length=20)
created_extension = APIBasedExtensionService.save(extension_data)
# Try to get extension with wrong tenant ID
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)

View File

@ -0,0 +1,473 @@
import json
from unittest.mock import MagicMock, patch
import pytest
import yaml
from faker import Faker
from models.model import App, AppModelConfig
from services.account_service import AccountService, TenantService
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
from services.app_service import AppService
class TestAppDslService:
"""Integration tests for AppDslService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.app_dsl_service.WorkflowService") as mock_workflow_service,
patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service,
patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service,
patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy,
patch("services.app_dsl_service.redis_client") as mock_redis_client,
patch("services.app_dsl_service.app_was_created") as mock_app_was_created,
patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated,
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
):
# Setup default mock returns
mock_workflow_service.return_value.get_draft_workflow.return_value = None
mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock()
mock_dependencies_service.generate_latest_dependencies.return_value = []
mock_dependencies_service.get_leaked_dependencies.return_value = []
mock_dependencies_service.generate_dependencies.return_value = []
mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None
mock_ssrf_proxy.get.return_value.content = b"test content"
mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None
mock_redis_client.setex.return_value = None
mock_redis_client.get.return_value = None
mock_redis_client.delete.return_value = None
mock_app_was_created.send.return_value = None
mock_app_model_config_was_updated.send.return_value = None
# Mock ModelManager for app service
mock_model_instance = mock_model_manager.return_value
mock_model_instance.get_default_model_instance.return_value = None
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
# Mock FeatureService and EnterpriseService
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
yield {
"workflow_service": mock_workflow_service,
"dependencies_service": mock_dependencies_service,
"draft_variable_service": mock_draft_variable_service,
"ssrf_proxy": mock_ssrf_proxy,
"redis_client": mock_redis_client,
"app_was_created": mock_app_was_created,
"app_model_config_was_updated": mock_app_model_config_was_updated,
"model_manager": mock_model_manager,
"feature_service": mock_feature_service,
"enterprise_service": mock_enterprise_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (app, account) - Created app and account instances
"""
fake = Faker()
# Setup mocks for account creation
with patch("services.account_service.FeatureService") as mock_account_feature_service:
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Setup app creation arguments
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
# Create app
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
return app, account
def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"):
"""
Helper method to create simple YAML content for testing.
"""
yaml_data = {
"version": "0.3.0",
"kind": "app",
"app": {
"name": app_name,
"mode": app_mode,
"icon": "🤖",
"icon_background": "#FFEAD5",
"description": "Test app description",
"use_icon_as_answer_icon": False,
},
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 1000,
"temperature": 0.7,
"top_p": 1.0,
},
},
"pre_prompt": "You are a helpful assistant.",
"prompt_type": "simple",
},
}
return yaml.dump(yaml_data, allow_unicode=True)
def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app import from YAML content.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create YAML content
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
# Import app
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name="Imported App",
description="Imported app description",
)
# Verify import result
assert result.status == ImportStatus.COMPLETED
assert result.app_id is not None
assert result.app_mode == "chat"
assert result.imported_dsl_version == "0.3.0"
assert result.error == ""
# Verify app was created in database
imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
assert imported_app is not None
assert imported_app.name == "Imported App"
assert imported_app.description == "Imported app description"
assert imported_app.mode == "chat"
assert imported_app.tenant_id == account.current_tenant_id
assert imported_app.created_by == account.id
# Verify model config was created
model_config = (
db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first()
)
assert model_config is not None
# The provider and model_id are stored in the model field as JSON
model_dict = model_config.model_dict
assert model_dict["provider"] == "openai"
assert model_dict["name"] == "gpt-3.5-turbo"
def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app import from YAML URL.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create YAML content for mock response
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
# Setup mock response
mock_response = MagicMock()
mock_response.content = yaml_content.encode("utf-8")
mock_response.raise_for_status.return_value = None
mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response
# Import app from URL
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_URL,
yaml_url="https://example.com/app.yaml",
name="URL Imported App",
description="App imported from URL",
)
# Verify import result
assert result.status == ImportStatus.COMPLETED
assert result.app_id is not None
assert result.app_mode == "chat"
assert result.imported_dsl_version == "0.3.0"
assert result.error == ""
# Verify app was created in database
imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
assert imported_app is not None
assert imported_app.name == "URL Imported App"
assert imported_app.description == "App imported from URL"
assert imported_app.mode == "chat"
assert imported_app.tenant_id == account.current_tenant_id
# Verify ssrf_proxy was called
mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with(
"https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10)
)
def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with invalid YAML format.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create invalid YAML content
invalid_yaml = "invalid: yaml: content: ["
# Import app with invalid YAML
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=invalid_yaml,
name="Invalid App",
)
# Verify import failed
assert result.status == ImportStatus.FAILED
assert result.app_id is None
assert "Invalid YAML format" in result.error
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with missing YAML content.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Import app without YAML content
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
name="Missing Content App",
)
# Verify import failed
assert result.status == ImportStatus.FAILED
assert result.app_id is None
assert "yaml_content is required" in result.error
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with missing YAML URL.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Import app without YAML URL
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_URL,
name="Missing URL App",
)
# Verify import failed
assert result.status == ImportStatus.FAILED
assert result.app_id is None
assert "yaml_url is required" in result.error
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with invalid import mode.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create YAML content
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
# Import app with invalid mode should raise ValueError
dsl_service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"):
dsl_service.import_app(
account=account,
import_mode="invalid-mode",
yaml_content=yaml_content,
name="Invalid Mode App",
)
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful DSL export for chat app.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create model config for the app
model_config = AppModelConfig()
model_config.id = fake.uuid4()
model_config.app_id = app.id
model_config.provider = "openai"
model_config.model_id = "gpt-3.5-turbo"
model_config.model = json.dumps(
{
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 1000,
"temperature": 0.7,
},
}
)
model_config.pre_prompt = "You are a helpful assistant."
model_config.prompt_type = "simple"
model_config.created_by = account.id
model_config.updated_by = account.id
# Set the app_model_config_id to link the config
app.app_model_config_id = model_config.id
db_session_with_containers.add(model_config)
db_session_with_containers.commit()
# Export DSL
exported_dsl = AppDslService.export_dsl(app, include_secret=False)
# Parse exported YAML
exported_data = yaml.safe_load(exported_dsl)
# Verify exported data structure
assert exported_data["kind"] == "app"
assert exported_data["app"]["name"] == app.name
assert exported_data["app"]["mode"] == app.mode
assert exported_data["app"]["icon"] == app.icon
assert exported_data["app"]["icon_background"] == app.icon_background
assert exported_data["app"]["description"] == app.description
# Verify model config was exported
assert "model_config" in exported_data
# The exported model_config structure may be different from the database structure
# Check that the model config exists and has the expected content
assert exported_data["model_config"] is not None
# Verify dependencies were exported
assert "dependencies" in exported_data
assert isinstance(exported_data["dependencies"], list)
def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful DSL export for workflow app.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Update app to workflow mode
app.mode = "workflow"
db_session_with_containers.commit()
# Mock workflow service to return a workflow
mock_workflow = MagicMock()
mock_workflow.to_dict.return_value = {
"graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
"features": {},
"environment_variables": [],
"conversation_variables": [],
}
mock_external_service_dependencies[
"workflow_service"
].return_value.get_draft_workflow.return_value = mock_workflow
# Export DSL
exported_dsl = AppDslService.export_dsl(app, include_secret=False)
# Parse exported YAML
exported_data = yaml.safe_load(exported_dsl)
# Verify exported data structure
assert exported_data["kind"] == "app"
assert exported_data["app"]["name"] == app.name
assert exported_data["app"]["mode"] == "workflow"
# Verify workflow was exported
assert "workflow" in exported_data
assert "graph" in exported_data["workflow"]
assert "nodes" in exported_data["workflow"]["graph"]
# Verify dependencies were exported
assert "dependencies" in exported_data
assert isinstance(exported_data["dependencies"], list)
# Verify workflow service was called
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
app
)
def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful dependency checking.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Mock Redis to return dependencies
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json
# Check dependencies
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.check_dependencies(app_model=app)
# Verify result
assert result.leaked_dependencies == []
# Verify Redis was queried
mock_external_service_dependencies["redis_client"].get.assert_called_once_with(
f"app_check_dependencies:{app.id}"
)
# Verify dependencies service was called
mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once()

View File

@ -0,0 +1,928 @@
from unittest.mock import patch
import pytest
from faker import Faker
from constants.model_template import default_app_templates
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
class TestAppService:
"""Integration tests for AppService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
# Setup default mock returns for app service
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
# Setup default mock returns for account service
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
# Mock ModelManager for model configuration
mock_model_instance = mock_model_manager.return_value
mock_model_instance.get_default_model_instance.return_value = None
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
yield {
"feature_service": mock_feature_service,
"enterprise_service": mock_enterprise_service,
"model_manager": mock_model_manager,
"account_feature_service": mock_account_feature_service,
}
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app creation with basic parameters.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Setup app creation arguments
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
# Create app
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Verify app was created correctly
assert app.name == app_args["name"]
assert app.description == app_args["description"]
assert app.mode == app_args["mode"]
assert app.icon_type == app_args["icon_type"]
assert app.icon == app_args["icon"]
assert app.icon_background == app_args["icon_background"]
assert app.tenant_id == tenant.id
assert app.api_rph == app_args["api_rph"]
assert app.api_rpm == app_args["api_rpm"]
assert app.created_by == account.id
assert app.updated_by == account.id
assert app.status == "normal"
assert app.enable_site is True
assert app.enable_api is True
assert app.is_demo is False
assert app.is_public is False
assert app.is_universal is False
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app creation with different app modes.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_service = AppService()
# Test each app mode defined by the AppMode keys of default_app_templates
app_modes = [v.value for v in default_app_templates]
for mode in app_modes:
app_args = {
"name": f"{fake.company()} {mode}",
"description": f"Test app for {mode} mode",
"mode": mode,
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app = app_service.create_app(tenant.id, app_args, account)
# Verify app mode was set correctly
assert app.mode == mode
assert app.name == app_args["name"]
assert app.tenant_id == tenant.id
assert app.created_by == account.id
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app retrieval.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
# Get app using the service
retrieved_app = app_service.get_app(created_app)
# Verify retrieved app matches created app
assert retrieved_app.id == created_app.id
assert retrieved_app.name == created_app.name
assert retrieved_app.description == created_app.description
assert retrieved_app.mode == created_app.mode
assert retrieved_app.tenant_id == created_app.tenant_id
assert retrieved_app.created_by == created_app.created_by
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful paginated app list retrieval.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_service = AppService()
# Create multiple apps
app_names = [fake.company() for _ in range(5)]
for name in app_names:
app_args = {
"name": name,
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "📱",
"icon_background": "#96CEB4",
}
app_service.create_app(tenant.id, app_args, account)
# Get paginated apps
args = {
"page": 1,
"limit": 10,
"mode": "chat",
}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
# Verify pagination results
assert paginated_apps is not None
assert len(paginated_apps.items) >= 5 # Should have at least 5 apps
assert paginated_apps.page == 1
assert paginated_apps.per_page == 10
# Verify all apps belong to the correct tenant
for app in paginated_apps.items:
assert app.tenant_id == tenant.id
assert app.mode == "chat"
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test paginated app list with various filters.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_service = AppService()
# Create apps with different modes
chat_app_args = {
"name": "Chat App",
"description": "A chat application",
"mode": "chat",
"icon_type": "emoji",
"icon": "💬",
"icon_background": "#FF6B6B",
}
completion_app_args = {
"name": "Completion App",
"description": "A completion application",
"mode": "completion",
"icon_type": "emoji",
"icon": "✍️",
"icon_background": "#4ECDC4",
}
chat_app = app_service.create_app(tenant.id, chat_app_args, account)
completion_app = app_service.create_app(tenant.id, completion_app_args, account)
# Test filter by mode
chat_args = {
"page": 1,
"limit": 10,
"mode": "chat",
}
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
assert len(chat_apps.items) == 1
assert chat_apps.items[0].mode == "chat"
# Test filter by name
name_args = {
"page": 1,
"limit": 10,
"mode": "chat",
"name": "Chat",
}
filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
assert len(filtered_apps.items) == 1
assert "Chat" in filtered_apps.items[0].name
# Test filter by created_by_me
created_by_me_args = {
"page": 1,
"limit": 10,
"mode": "completion",
"is_created_by_me": True,
}
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
assert len(my_apps.items) == 1
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test paginated app list with tag filters.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_service = AppService()
# Create an app
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🏷️",
"icon_background": "#FFEAA7",
}
app = app_service.create_app(tenant.id, app_args, account)
# Mock TagService to return the app ID for tag filtering
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
mock_tag_service.return_value = [app.id]
# Test with tag filter
args = {
"page": 1,
"limit": 10,
"mode": "chat",
"tag_ids": ["tag1", "tag2"],
}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
# Verify tag service was called
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])
# Verify results
assert paginated_apps is not None
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].id == app.id
# Test with tag filter that returns no results
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
mock_tag_service.return_value = []
args = {
"page": 1,
"limit": 10,
"mode": "chat",
"tag_ids": ["nonexistent_tag"],
}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
# Should return None when no apps match tag filter
assert paginated_apps is None
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app update with all fields.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store original values
original_name = app.name
original_description = app.description
original_icon = app.icon
original_icon_background = app.icon_background
original_use_icon_as_answer_icon = app.use_icon_as_answer_icon
# Update app
update_args = {
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": "emoji",
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
}
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app(app, update_args)
# Verify updated fields
assert updated_app.name == update_args["name"]
assert updated_app.description == update_args["description"]
assert updated_app.icon == update_args["icon"]
assert updated_app.icon_background == update_args["icon_background"]
assert updated_app.use_icon_as_answer_icon is True
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
assert updated_app.mode == app.mode
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app name update.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store original name
original_name = app.name
# Update app name
new_name = "New App Name"
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app_name(app, new_name)
assert updated_app.name == new_name
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
assert updated_app.description == app.description
assert updated_app.mode == app.mode
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app icon update.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store original values
original_icon = app.icon
original_icon_background = app.icon_background
# Update app icon
new_icon = "🌟"
new_icon_background = "#FFD93D"
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
assert updated_app.icon == new_icon
assert updated_app.icon_background == new_icon_background
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
assert updated_app.name == app.name
assert updated_app.description == app.description
assert updated_app.mode == app.mode
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app site status update.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🌐",
"icon_background": "#74B9FF",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store original site status
original_site_status = app.enable_site
# Update site status to disabled
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app_site_status(app, False)
assert updated_app.enable_site is False
assert updated_app.updated_by == account.id
# Update site status back to enabled
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app_site_status(updated_app, True)
assert updated_app.enable_site is True
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
assert updated_app.name == app.name
assert updated_app.description == app.description
assert updated_app.mode == app.mode
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app API status update.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🔌",
"icon_background": "#A29BFE",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store original API status
original_api_status = app.enable_api
# Update API status to disabled
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app_api_status(app, False)
assert updated_app.enable_api is False
assert updated_app.updated_by == account.id
# Update API status back to enabled
with patch("flask_login.utils._get_user", return_value=account):
updated_app = app_service.update_app_api_status(updated_app, True)
assert updated_app.enable_api is True
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
assert updated_app.name == app.name
assert updated_app.description == app.description
assert updated_app.mode == app.mode
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app site status update when status doesn't change.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🔄",
"icon_background": "#FD79A8",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store original values
original_site_status = app.enable_site
original_updated_at = app.updated_at
# Update site status to the same value (no change)
updated_app = app_service.update_app_site_status(app, original_site_status)
# Verify app is returned unchanged
assert updated_app.id == app.id
assert updated_app.enable_site == original_site_status
assert updated_app.updated_at == original_updated_at
# Verify other fields remain unchanged
assert updated_app.name == app.name
assert updated_app.description == app.description
assert updated_app.mode == app.mode
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
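    # A minimal sketch of the early-return pattern the no-change test above depends on;
    # the function name and parameters here are assumptions, not the real AppService code.
    # from datetime import datetime, timezone
    #
    # def update_app_site_status_sketch(app, enable_site, updated_by, session):
    #     if app.enable_site == enable_site:
    #         return app  # untouched object, so updated_at is not bumped
    #     app.enable_site = enable_site
    #     app.updated_by = updated_by
    #     app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
    #     session.commit()
    #     return app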
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app deletion.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🗑️",
"icon_background": "#E17055",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store app ID for verification
app_id = app.id
# Mock the async deletion task
with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
mock_delete_task.delay.return_value = None
# Delete app
app_service.delete_app(app)
# Verify async deletion task was called
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
# Verify app was deleted from database
from extensions.ext_database import db
deleted_app = db.session.query(App).filter_by(id=app_id).first()
assert deleted_app is None
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app deletion with related data cleanup.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🧹",
"icon_background": "#00B894",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Store app ID for verification
app_id = app.id
# Mock webapp auth cleanup
mock_external_service_dependencies[
"feature_service"
].get_system_features.return_value.webapp_auth.enabled = True
# Mock the async deletion task
with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
mock_delete_task.delay.return_value = None
# Delete app
app_service.delete_app(app)
# Verify webapp auth cleanup was called
mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with(
app_id
)
# Verify async deletion task was called
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
# Verify app was deleted from database
from extensions.ext_database import db
deleted_app = db.session.query(App).filter_by(id=app_id).first()
assert deleted_app is None
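    # A hedged sketch of the deletion flow both deletion tests exercise. The collaborators
    # are passed in as callables because EnterpriseService.WebAppAuth.cleanup_webapp and
    # remove_app_and_related_data_task.delay are only known here as the patch targets in the tests.
    # def delete_app_sketch(session, app, webapp_auth_enabled, cleanup_webapp, enqueue_removal):
    #     if webapp_auth_enabled:
    #         cleanup_webapp(app.id)  # web-app auth cleanup happens first
    #     app_id, tenant_id = app.id, app.tenant_id
    #     session.delete(app)  # the tests check the App row is gone right after delete_app returns
    #     session.commit()
    #     enqueue_removal(tenant_id=tenant_id, app_id=app_id)  # async cascade cleanup of related data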
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app metadata retrieval.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "📊",
"icon_background": "#6C5CE7",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Get app metadata
app_meta = app_service.get_app_meta(app)
# Verify metadata contains expected fields
assert "tool_icons" in app_meta
# Note: get_app_meta currently only returns tool_icons
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app code retrieval by app ID.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🔗",
"icon_background": "#FDCB6E",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Get app code by ID
app_code = AppService.get_app_code_by_id(app.id)
# Verify app code was retrieved correctly
# Note: a Site is created automatically when the App is created; its code is auto-generated
assert app_code is not None
assert len(app_code) > 0
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful app ID retrieval by app code.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🆔",
"icon_background": "#E84393",
}
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Create a site for the app
site = Site()
site.app_id = app.id
site.code = fake.postalcode()
site.title = fake.company()
site.status = "normal"
site.default_language = "en-US"
site.customize_token_strategy = "uuid"
from extensions.ext_database import db
db.session.add(site)
db.session.commit()
# Get app ID by code
app_id = AppService.get_app_id_by_code(site.code)
# Verify app ID was retrieved correctly
assert app_id == app.id
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app creation with invalid mode.
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Setup app creation arguments with invalid mode
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "invalid_mode", # Invalid mode
"icon_type": "emoji",
"icon": "",
"icon_background": "#D63031",
}
app_service = AppService()
# Attempt to create app with invalid mode
with pytest.raises(ValueError, match="invalid mode value"):
app_service.create_app(tenant.id, app_args, account)
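For reference, a minimal sketch of the mode check that would raise the ValueError matched above, assuming the allowed set mirrors the modes used elsewhere in these tests (the real enum lives in the application code and is not shown here):

ALLOWED_APP_MODES = {"chat", "agent-chat", "advanced-chat", "completion", "workflow"}  # assumed set

def validate_app_mode(mode: str) -> str:
    # Raise the same message the test matches when the mode is unknown.
    if mode not in ALLOWED_APP_MODES:
        raise ValueError("invalid mode value")
    return mode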

View File

@ -0,0 +1,252 @@
import pytest
from controllers.console.app.app import _validate_description_length as app_validate
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate
class TestDescriptionValidationUnit:
"""Unit tests for description validation functions in App and Dataset APIs"""
def test_app_validate_description_length_valid(self):
"""Test App validation function with valid descriptions"""
# Empty string should be valid
assert app_validate("") == ""
# None should be valid
assert app_validate(None) is None
# Short description should be valid
short_desc = "Short description"
assert app_validate(short_desc) == short_desc
# Exactly 400 characters should be valid
exactly_400 = "x" * 400
assert app_validate(exactly_400) == exactly_400
# Just under limit should be valid
just_under = "x" * 399
assert app_validate(just_under) == just_under
def test_app_validate_description_length_invalid(self):
"""Test App validation function with invalid descriptions"""
# 401 characters should fail
just_over = "x" * 401
with pytest.raises(ValueError) as exc_info:
app_validate(just_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
# 500 characters should fail
way_over = "x" * 500
with pytest.raises(ValueError) as exc_info:
app_validate(way_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
# 1000 characters should fail
very_long = "x" * 1000
with pytest.raises(ValueError) as exc_info:
app_validate(very_long)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
def test_dataset_validate_description_length_valid(self):
"""Test Dataset validation function with valid descriptions"""
# Empty string should be valid
assert dataset_validate("") == ""
# Short description should be valid
short_desc = "Short description"
assert dataset_validate(short_desc) == short_desc
# Exactly 400 characters should be valid
exactly_400 = "x" * 400
assert dataset_validate(exactly_400) == exactly_400
# Just under limit should be valid
just_under = "x" * 399
assert dataset_validate(just_under) == just_under
def test_dataset_validate_description_length_invalid(self):
"""Test Dataset validation function with invalid descriptions"""
# 401 characters should fail
just_over = "x" * 401
with pytest.raises(ValueError) as exc_info:
dataset_validate(just_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
# 500 characters should fail
way_over = "x" * 500
with pytest.raises(ValueError) as exc_info:
dataset_validate(way_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
def test_service_dataset_validate_description_length_valid(self):
"""Test Service Dataset validation function with valid descriptions"""
# Empty string should be valid
assert service_dataset_validate("") == ""
# None should be valid
assert service_dataset_validate(None) is None
# Short description should be valid
short_desc = "Short description"
assert service_dataset_validate(short_desc) == short_desc
# Exactly 400 characters should be valid
exactly_400 = "x" * 400
assert service_dataset_validate(exactly_400) == exactly_400
# Just under limit should be valid
just_under = "x" * 399
assert service_dataset_validate(just_under) == just_under
def test_service_dataset_validate_description_length_invalid(self):
"""Test Service Dataset validation function with invalid descriptions"""
# 401 characters should fail
just_over = "x" * 401
with pytest.raises(ValueError) as exc_info:
service_dataset_validate(just_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
# 500 characters should fail
way_over = "x" * 500
with pytest.raises(ValueError) as exc_info:
service_dataset_validate(way_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
def test_app_dataset_validation_consistency(self):
"""Test that App and Dataset validation functions behave identically"""
test_cases = [
"", # Empty string
"Short description", # Normal description
"x" * 100, # Medium description
"x" * 400, # Exactly at limit
]
# Test valid cases produce same results
for test_desc in test_cases:
assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc)
# Test invalid cases produce same errors
invalid_cases = [
"x" * 401, # Just over limit
"x" * 500, # Way over limit
"x" * 1000, # Very long
]
for invalid_desc in invalid_cases:
app_error = None
dataset_error = None
service_dataset_error = None
# Capture App validation error
try:
app_validate(invalid_desc)
except ValueError as e:
app_error = str(e)
# Capture Dataset validation error
try:
dataset_validate(invalid_desc)
except ValueError as e:
dataset_error = str(e)
# Capture Service Dataset validation error
try:
service_dataset_validate(invalid_desc)
except ValueError as e:
service_dataset_error = str(e)
# All should produce errors
assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters"
assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters"
error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters"
assert service_dataset_error is not None, error_msg
# Errors should be identical
error_msg = f"Error messages should be identical for {len(invalid_desc)} characters"
assert app_error == dataset_error == service_dataset_error, error_msg
assert app_error == "Description cannot exceed 400 characters."
def test_boundary_values(self):
"""Test boundary values around the 400 character limit"""
boundary_tests = [
(0, True), # Empty
(1, True), # Minimum
(399, True), # Just under limit
(400, True), # Exactly at limit
(401, False), # Just over limit
(402, False), # Over limit
(500, False), # Way over limit
]
for length, should_pass in boundary_tests:
test_desc = "x" * length
if should_pass:
# Should not raise exception
assert app_validate(test_desc) == test_desc
assert dataset_validate(test_desc) == test_desc
assert service_dataset_validate(test_desc) == test_desc
else:
# Should raise ValueError
with pytest.raises(ValueError):
app_validate(test_desc)
with pytest.raises(ValueError):
dataset_validate(test_desc)
with pytest.raises(ValueError):
service_dataset_validate(test_desc)
def test_special_characters(self):
"""Test validation with special characters, Unicode, etc."""
# Unicode characters
unicode_desc = "测试描述" * 100 # Chinese characters
if len(unicode_desc) <= 400:
assert app_validate(unicode_desc) == unicode_desc
assert dataset_validate(unicode_desc) == unicode_desc
assert service_dataset_validate(unicode_desc) == unicode_desc
# Special characters
special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10
if len(special_desc) <= 400:
assert app_validate(special_desc) == special_desc
assert dataset_validate(special_desc) == special_desc
assert service_dataset_validate(special_desc) == special_desc
# Mixed content
mixed_desc = "Mixed content: 测试 123 !@# " * 15
if len(mixed_desc) <= 400:
assert app_validate(mixed_desc) == mixed_desc
assert dataset_validate(mixed_desc) == mixed_desc
assert service_dataset_validate(mixed_desc) == mixed_desc
elif len(mixed_desc) > 400:
with pytest.raises(ValueError):
app_validate(mixed_desc)
with pytest.raises(ValueError):
dataset_validate(mixed_desc)
with pytest.raises(ValueError):
service_dataset_validate(mixed_desc)
def test_whitespace_handling(self):
"""Test validation with various whitespace scenarios"""
# Leading/trailing whitespace
whitespace_desc = " Description with whitespace "
if len(whitespace_desc) <= 400:
assert app_validate(whitespace_desc) == whitespace_desc
assert dataset_validate(whitespace_desc) == whitespace_desc
assert service_dataset_validate(whitespace_desc) == whitespace_desc
# Newlines and tabs
multiline_desc = "Line 1\nLine 2\tTabbed content"
if len(multiline_desc) <= 400:
assert app_validate(multiline_desc) == multiline_desc
assert dataset_validate(multiline_desc) == multiline_desc
assert service_dataset_validate(multiline_desc) == multiline_desc
# Only whitespace over limit
only_spaces = " " * 401
with pytest.raises(ValueError):
app_validate(only_spaces)
with pytest.raises(ValueError):
dataset_validate(only_spaces)
with pytest.raises(ValueError):
service_dataset_validate(only_spaces)
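All three controllers above are expected to share the same contract; a minimal sketch of such a validator, assuming it simply passes empty or None values through and rejects anything longer than 400 characters with the exact message the tests match:

def _validate_description_length(description):
    # Pass through falsy values; reject descriptions over 400 characters.
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description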

View File

@ -0,0 +1,336 @@
"""
Unit tests for Service API File Preview endpoint
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError
from controllers.service_api.app.file_preview import FilePreviewApi
from models.model import App, EndUser, Message, MessageFile, UploadFile
class TestFilePreviewApi:
"""Test suite for FilePreviewApi"""
@pytest.fixture
def file_preview_api(self):
"""Create FilePreviewApi instance for testing"""
return FilePreviewApi()
@pytest.fixture
def mock_app(self):
"""Mock App model"""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_end_user(self):
"""Mock EndUser model"""
end_user = Mock(spec=EndUser)
end_user.id = str(uuid.uuid4())
return end_user
@pytest.fixture
def mock_upload_file(self):
"""Mock UploadFile model"""
upload_file = Mock(spec=UploadFile)
upload_file.id = str(uuid.uuid4())
upload_file.name = "test_file.jpg"
upload_file.mime_type = "image/jpeg"
upload_file.size = 1024
upload_file.key = "storage/key/test_file.jpg"
upload_file.tenant_id = str(uuid.uuid4())
return upload_file
@pytest.fixture
def mock_message_file(self):
"""Mock MessageFile model"""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.upload_file_id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
return message_file
@pytest.fixture
def mock_message(self):
"""Mock Message model"""
message = Mock(spec=Message)
message.id = str(uuid.uuid4())
message.app_id = str(uuid.uuid4())
return message
def test_validate_file_ownership_success(
self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
):
"""Test successful file ownership validation"""
file_id = str(uuid.uuid4())
app_id = mock_app.id
# Set up the mocks
mock_upload_file.tenant_id = mock_app.tenant_id
mock_message.app_id = app_id
mock_message_file.upload_file_id = file_id
mock_message_file.message_id = mock_message.id
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
# Execute the method
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
# Assertions
assert result_message_file == mock_message_file
assert result_upload_file == mock_upload_file
def test_validate_file_ownership_file_not_found(self, file_preview_api):
"""Test file ownership validation when MessageFile not found"""
file_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock MessageFile not found
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Execute and assert exception
with pytest.raises(FileNotFoundError) as exc_info:
file_preview_api._validate_file_ownership(file_id, app_id)
assert "File not found in message context" in str(exc_info.value)
def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file):
"""Test file ownership validation when Message not owned by app"""
file_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock MessageFile found but Message not owned by app
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_message_file, # MessageFile query - found
None, # Message query - not found (access denied)
]
# Execute and assert exception
with pytest.raises(FileAccessDeniedError) as exc_info:
file_preview_api._validate_file_ownership(file_id, app_id)
assert "not owned by requesting app" in str(exc_info.value)
def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message):
"""Test file ownership validation when UploadFile not found"""
file_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock MessageFile and Message found but UploadFile not found
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_message_file, # MessageFile query - found
mock_message, # Message query - found
None, # UploadFile query - not found
]
# Execute and assert exception
with pytest.raises(FileNotFoundError) as exc_info:
file_preview_api._validate_file_ownership(file_id, app_id)
assert "Upload file record not found" in str(exc_info.value)
def test_validate_file_ownership_tenant_mismatch(
self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
):
"""Test file ownership validation with tenant mismatch"""
file_id = str(uuid.uuid4())
app_id = mock_app.id
# Set up tenant mismatch
mock_upload_file.tenant_id = "different_tenant_id"
mock_app.tenant_id = "app_tenant_id"
mock_message.app_id = app_id
mock_message_file.upload_file_id = file_id
mock_message_file.message_id = mock_message.id
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
# Execute and assert exception
with pytest.raises(FileAccessDeniedError) as exc_info:
file_preview_api._validate_file_ownership(file_id, app_id)
assert "tenant mismatch" in str(exc_info.value)
def test_validate_file_ownership_invalid_input(self, file_preview_api):
"""Test file ownership validation with invalid input"""
# Test with empty file_id
with pytest.raises(FileAccessDeniedError) as exc_info:
file_preview_api._validate_file_ownership("", "app_id")
assert "Invalid file or app identifier" in str(exc_info.value)
# Test with empty app_id
with pytest.raises(FileAccessDeniedError) as exc_info:
file_preview_api._validate_file_ownership("file_id", "")
assert "Invalid file or app identifier" in str(exc_info.value)
def test_build_file_response_basic(self, file_preview_api, mock_upload_file):
"""Test basic file response building"""
mock_generator = Mock()
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
# Check response properties
assert response.mimetype == mock_upload_file.mime_type
assert response.direct_passthrough is True
assert response.headers["Content-Length"] == str(mock_upload_file.size)
assert "Cache-Control" in response.headers
def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file):
"""Test file response building with attachment flag"""
mock_generator = Mock()
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True)
# Check attachment-specific headers
assert "attachment" in response.headers["Content-Disposition"]
assert mock_upload_file.name in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
"""Test file response building for audio/video files"""
mock_generator = Mock()
mock_upload_file.mime_type = "video/mp4"
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
# Check Range support for media files
assert response.headers["Accept-Ranges"] == "bytes"
def test_build_file_response_no_size(self, file_preview_api, mock_upload_file):
"""Test file response building when size is unknown"""
mock_generator = Mock()
mock_upload_file.size = 0 # Unknown size
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
# Content-Length should not be set when size is unknown
assert "Content-Length" not in response.headers
@patch("controllers.service_api.app.file_preview.storage")
def test_get_method_integration(
self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message
):
"""Test the full GET method integration (without decorator)"""
file_id = str(uuid.uuid4())
app_id = mock_app.id
# Set up mocks
mock_upload_file.tenant_id = mock_app.tenant_id
mock_message.app_id = app_id
mock_message_file.upload_file_id = file_id
mock_message_file.message_id = mock_message.id
mock_generator = Mock()
mock_storage.load.return_value = mock_generator
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse:
# Mock request parsing
mock_parser = Mock()
mock_parser.parse_args.return_value = {"as_attachment": False}
mock_reqparse.RequestParser.return_value = mock_parser
# Test the core logic directly without Flask decorators
# Validate file ownership
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
assert result_message_file == mock_message_file
assert result_upload_file == mock_upload_file
# Test file response building
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
assert response is not None
# Verify storage was called correctly
mock_storage.load.assert_not_called() # Since we're testing components separately
@patch("controllers.service_api.app.file_preview.storage")
def test_storage_error_handling(
self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
):
"""Test storage error handling in the core logic"""
file_id = str(uuid.uuid4())
app_id = mock_app.id
# Set up mocks
mock_upload_file.tenant_id = mock_app.tenant_id
mock_message.app_id = app_id
mock_message_file.upload_file_id = file_id
mock_message_file.message_id = mock_message.id
# Mock storage error
mock_storage.load.side_effect = Exception("Storage error")
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries for validation
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
# First validate file ownership works
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
assert result_message_file == mock_message_file
assert result_upload_file == mock_upload_file
# Test storage error handling
with pytest.raises(Exception) as exc_info:
mock_storage.load(mock_upload_file.key, stream=True)
assert "Storage error" in str(exc_info.value)
@patch("controllers.service_api.app.file_preview.logger")
def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api):
"""Test that unexpected errors are logged properly"""
file_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database query to raise unexpected exception
mock_db.session.query.side_effect = Exception("Unexpected database error")
# Execute and assert exception
with pytest.raises(FileAccessDeniedError) as exc_info:
file_preview_api._validate_file_ownership(file_id, app_id)
# Verify error message
assert "File access validation failed" in str(exc_info.value)
# Verify logging was called
mock_logger.exception.assert_called_once_with(
"Unexpected error during file ownership validation",
extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"},
)
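The mocked query order in these tests implies a four-step ownership chain; a hedged sketch of that flow follows. The imports match the ones at the top of the test file, while the query construction itself is an assumption.

from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError
from models.model import App, Message, MessageFile, UploadFile

def validate_file_ownership(session, file_id, app_id):
    # Assumed chain, matching the mocked query order: MessageFile -> Message -> UploadFile -> tenant check.
    if not file_id or not app_id:
        raise FileAccessDeniedError("Invalid file or app identifier")
    message_file = session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()
    if message_file is None:
        raise FileNotFoundError("File not found in message context")
    message = session.query(Message).where(
        Message.id == message_file.message_id, Message.app_id == app_id
    ).first()
    if message is None:
        raise FileAccessDeniedError("File not owned by requesting app")
    upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
    if upload_file is None:
        raise FileNotFoundError("Upload file record not found")
    app = session.query(App).where(App.id == app_id).first()
    if app is None or upload_file.tenant_id != app.tenant_id:
        raise FileAccessDeniedError("File access denied: tenant mismatch")
    return message_file, upload_file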

View File

@ -0,0 +1,419 @@
"""Test conversation variable handling in AdvancedChatAppRunner."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.variables import SegmentType
from factories import variable_factory
from models import ConversationVariable, Workflow
class TestAdvancedChatAppRunnerConversationVariables:
"""Test that AdvancedChatAppRunner correctly handles conversation variables."""
def test_missing_conversation_variables_are_added(self):
"""Test that new conversation variables added to workflow are created for existing conversations."""
# Setup
app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_id = str(uuid4())
# Create workflow with two conversation variables
workflow_vars = [
variable_factory.build_conversation_variable_from_mapping(
{
"id": "var1",
"name": "existing_var",
"value_type": SegmentType.STRING,
"value": "default1",
}
),
variable_factory.build_conversation_variable_from_mapping(
{
"id": "var2",
"name": "new_var",
"value_type": SegmentType.STRING,
"value": "default2",
}
),
]
# Mock workflow with conversation variables
mock_workflow = MagicMock(spec=Workflow)
mock_workflow.conversation_variables = workflow_vars
mock_workflow.tenant_id = str(uuid4())
mock_workflow.app_id = app_id
mock_workflow.id = workflow_id
mock_workflow.type = "chat"
mock_workflow.graph_dict = {}
mock_workflow.environment_variables = []
# Create existing conversation variable (only var1 exists in DB)
existing_db_var = MagicMock(spec=ConversationVariable)
existing_db_var.id = "var1"
existing_db_var.app_id = app_id
existing_db_var.conversation_id = conversation_id
existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0])
# Mock conversation and message
mock_conversation = MagicMock()
mock_conversation.app_id = app_id
mock_conversation.id = conversation_id
mock_message = MagicMock()
mock_message.id = str(uuid4())
# Mock app config
mock_app_config = MagicMock()
mock_app_config.app_id = app_id
mock_app_config.workflow_id = workflow_id
mock_app_config.tenant_id = str(uuid4())
# Mock app generate entity
mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
mock_app_generate_entity.app_config = mock_app_config
mock_app_generate_entity.inputs = {}
mock_app_generate_entity.query = "test query"
mock_app_generate_entity.files = []
mock_app_generate_entity.user_id = str(uuid4())
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
mock_app_generate_entity.workflow_run_id = str(uuid4())
mock_app_generate_entity.call_depth = 0
mock_app_generate_entity.single_iteration_run = None
mock_app_generate_entity.single_loop_run = None
mock_app_generate_entity.trace_manager = None
# Create runner
runner = AdvancedChatAppRunner(
application_generate_entity=mock_app_generate_entity,
queue_manager=MagicMock(),
conversation=mock_conversation,
message=mock_message,
dialogue_count=1,
variable_loader=MagicMock(),
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
)
# Mock database session
mock_session = MagicMock(spec=Session)
# First query returns only existing variable
mock_scalars_result = MagicMock()
mock_scalars_result.all.return_value = [existing_db_var]
mock_session.scalars.return_value = mock_scalars_result
# Track what gets added to session
added_items = []
def track_add_all(items):
added_items.extend(items)
mock_session.add_all.side_effect = track_add_all
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(runner, "handle_input_moderation", return_value=False),
patch.object(runner, "handle_annotation_reply", return_value=False),
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock graph initialization
mock_init_graph.return_value = MagicMock()
# Mock workflow entry
mock_workflow_entry = MagicMock()
mock_workflow_entry.run.return_value = iter([]) # Empty generator
mock_workflow_entry_class.return_value = mock_workflow_entry
# Run the method
runner.run()
# Verify that the missing variable was added
assert len(added_items) == 1, "Should have added exactly one missing variable"
# Check that the added item is the missing variable (var2)
added_var = added_items[0]
assert hasattr(added_var, "id"), "Added item should be a ConversationVariable"
# Note: Since we're mocking ConversationVariable.from_variable,
# we can't directly check the id, but we can verify add_all was called
assert mock_session.add_all.called, "Session add_all should have been called"
assert mock_session.commit.called, "Session commit should have been called"
def test_no_variables_creates_all(self):
"""Test that all conversation variables are created when none exist in DB."""
# Setup
app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_id = str(uuid4())
# Create workflow with conversation variables
workflow_vars = [
variable_factory.build_conversation_variable_from_mapping(
{
"id": "var1",
"name": "var1",
"value_type": SegmentType.STRING,
"value": "default1",
}
),
variable_factory.build_conversation_variable_from_mapping(
{
"id": "var2",
"name": "var2",
"value_type": SegmentType.STRING,
"value": "default2",
}
),
]
# Mock workflow
mock_workflow = MagicMock(spec=Workflow)
mock_workflow.conversation_variables = workflow_vars
mock_workflow.tenant_id = str(uuid4())
mock_workflow.app_id = app_id
mock_workflow.id = workflow_id
mock_workflow.type = "chat"
mock_workflow.graph_dict = {}
mock_workflow.environment_variables = []
# Mock conversation and message
mock_conversation = MagicMock()
mock_conversation.app_id = app_id
mock_conversation.id = conversation_id
mock_message = MagicMock()
mock_message.id = str(uuid4())
# Mock app config
mock_app_config = MagicMock()
mock_app_config.app_id = app_id
mock_app_config.workflow_id = workflow_id
mock_app_config.tenant_id = str(uuid4())
# Mock app generate entity
mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
mock_app_generate_entity.app_config = mock_app_config
mock_app_generate_entity.inputs = {}
mock_app_generate_entity.query = "test query"
mock_app_generate_entity.files = []
mock_app_generate_entity.user_id = str(uuid4())
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
mock_app_generate_entity.workflow_run_id = str(uuid4())
mock_app_generate_entity.call_depth = 0
mock_app_generate_entity.single_iteration_run = None
mock_app_generate_entity.single_loop_run = None
mock_app_generate_entity.trace_manager = None
# Create runner
runner = AdvancedChatAppRunner(
application_generate_entity=mock_app_generate_entity,
queue_manager=MagicMock(),
conversation=mock_conversation,
message=mock_message,
dialogue_count=1,
variable_loader=MagicMock(),
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
)
# Mock database session
mock_session = MagicMock(spec=Session)
# Query returns empty list (no existing variables)
mock_scalars_result = MagicMock()
mock_scalars_result.all.return_value = []
mock_session.scalars.return_value = mock_scalars_result
# Track what gets added to session
added_items = []
def track_add_all(items):
added_items.extend(items)
mock_session.add_all.side_effect = track_add_all
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(runner, "handle_input_moderation", return_value=False),
patch.object(runner, "handle_annotation_reply", return_value=False),
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock ConversationVariable.from_variable to return mock objects
mock_conv_vars = []
for var in workflow_vars:
mock_cv = MagicMock()
mock_cv.id = var.id
mock_cv.to_variable.return_value = var
mock_conv_vars.append(mock_cv)
mock_conv_var_class.from_variable.side_effect = mock_conv_vars
# Mock graph initialization
mock_init_graph.return_value = MagicMock()
# Mock workflow entry
mock_workflow_entry = MagicMock()
mock_workflow_entry.run.return_value = iter([]) # Empty generator
mock_workflow_entry_class.return_value = mock_workflow_entry
# Run the method
runner.run()
# Verify that all variables were created
assert len(added_items) == 2, "Should have added both variables"
assert mock_session.add_all.called, "Session add_all should have been called"
assert mock_session.commit.called, "Session commit should have been called"
def test_all_variables_exist_no_changes(self):
"""Test that no changes are made when all variables already exist in DB."""
# Setup
app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_id = str(uuid4())
# Create workflow with conversation variables
workflow_vars = [
variable_factory.build_conversation_variable_from_mapping(
{
"id": "var1",
"name": "var1",
"value_type": SegmentType.STRING,
"value": "default1",
}
),
variable_factory.build_conversation_variable_from_mapping(
{
"id": "var2",
"name": "var2",
"value_type": SegmentType.STRING,
"value": "default2",
}
),
]
# Mock workflow
mock_workflow = MagicMock(spec=Workflow)
mock_workflow.conversation_variables = workflow_vars
mock_workflow.tenant_id = str(uuid4())
mock_workflow.app_id = app_id
mock_workflow.id = workflow_id
mock_workflow.type = "chat"
mock_workflow.graph_dict = {}
mock_workflow.environment_variables = []
# Create existing conversation variables (both exist in DB)
existing_db_vars = []
for var in workflow_vars:
db_var = MagicMock(spec=ConversationVariable)
db_var.id = var.id
db_var.app_id = app_id
db_var.conversation_id = conversation_id
db_var.to_variable = MagicMock(return_value=var)
existing_db_vars.append(db_var)
# Mock conversation and message
mock_conversation = MagicMock()
mock_conversation.app_id = app_id
mock_conversation.id = conversation_id
mock_message = MagicMock()
mock_message.id = str(uuid4())
# Mock app config
mock_app_config = MagicMock()
mock_app_config.app_id = app_id
mock_app_config.workflow_id = workflow_id
mock_app_config.tenant_id = str(uuid4())
# Mock app generate entity
mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
mock_app_generate_entity.app_config = mock_app_config
mock_app_generate_entity.inputs = {}
mock_app_generate_entity.query = "test query"
mock_app_generate_entity.files = []
mock_app_generate_entity.user_id = str(uuid4())
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
mock_app_generate_entity.workflow_run_id = str(uuid4())
mock_app_generate_entity.call_depth = 0
mock_app_generate_entity.single_iteration_run = None
mock_app_generate_entity.single_loop_run = None
mock_app_generate_entity.trace_manager = None
# Create runner
runner = AdvancedChatAppRunner(
application_generate_entity=mock_app_generate_entity,
queue_manager=MagicMock(),
conversation=mock_conversation,
message=mock_message,
dialogue_count=1,
variable_loader=MagicMock(),
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
)
# Mock database session
mock_session = MagicMock(spec=Session)
# Query returns all existing variables
mock_scalars_result = MagicMock()
mock_scalars_result.all.return_value = existing_db_vars
mock_session.scalars.return_value = mock_scalars_result
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(runner, "handle_input_moderation", return_value=False),
patch.object(runner, "handle_annotation_reply", return_value=False),
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock graph initialization
mock_init_graph.return_value = MagicMock()
# Mock workflow entry
mock_workflow_entry = MagicMock()
mock_workflow_entry.run.return_value = iter([]) # Empty generator
mock_workflow_entry_class.return_value = mock_workflow_entry
# Run the method
runner.run()
# Verify that no variables were added
assert not mock_session.add_all.called, "Session add_all should not have been called"
assert mock_session.commit.called, "Session commit should still be called"
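The three scenarios above boil down to one reconciliation step inside the runner; a rough sketch under the assumption that it diffs workflow-declared variables against what already exists for the conversation and inserts only the missing rows. ConversationVariable.from_variable is the constructor the tests patch; its exact signature is an assumption here.

from sqlalchemy import select
from models import ConversationVariable

def ensure_conversation_variables(session, workflow, conversation):
    # Load what already exists for this conversation.
    existing = session.scalars(
        select(ConversationVariable).where(
            ConversationVariable.app_id == conversation.app_id,
            ConversationVariable.conversation_id == conversation.id,
        )
    ).all()
    existing_ids = {row.id for row in existing}
    # Create rows only for workflow variables the conversation does not have yet.
    missing = [
        ConversationVariable.from_variable(
            app_id=conversation.app_id, conversation_id=conversation.id, variable=var
        )
        for var in workflow.conversation_variables
        if var.id not in existing_ids
    ]
    if missing:
        session.add_all(missing)  # only called when something is actually missing
    session.commit()  # commit happens either way, as the third test expects
    return existing + missing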

View File

@ -49,7 +49,7 @@ def test_executor_with_json_body_and_number_variable():
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == []
assert executor.params is None
assert executor.json == {"number": 42}
assert executor.data is None
assert executor.files is None
@ -102,7 +102,7 @@ def test_executor_with_json_body_and_object_variable():
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == []
assert executor.params is None
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
assert executor.data is None
assert executor.files is None
@ -157,7 +157,7 @@ def test_executor_with_json_body_and_nested_object_variable():
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == []
assert executor.params is None
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
assert executor.data is None
assert executor.files is None
@ -245,7 +245,7 @@ def test_executor_with_form_data():
assert executor.url == "https://api.example.com/upload"
assert "Content-Type" in executor.headers
assert "multipart/form-data" in executor.headers["Content-Type"]
assert executor.params == []
assert executor.params is None
assert executor.json is None
# '__multipart_placeholder__' is expected when no file inputs exist,
# to ensure the request is treated as multipart/form-data by the backend.
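These hunks track a behavioral change in the HTTP executor's default: a minimal illustration, assuming params is now simply left unset (None) when the node config declares no query parameters, rather than being normalized to an empty list.

class Executor:
    # Hypothetical fragment, not the project's actual Executor; it only illustrates the new default.
    def __init__(self, params=None):
        self.params = params  # previously normalized to [], now left as None when nothing is configured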

View File

@ -0,0 +1,127 @@
import uuid
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from services.conversation_service import ConversationService
class TestConversationService:
def test_pagination_with_empty_include_ids(self):
"""Test that empty include_ids returns empty result"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=[], # Empty include_ids should return empty result
exclude_ids=None,
)
assert result.data == []
assert result.has_more is False
assert result.limit == 20
def test_pagination_with_non_empty_include_ids(self):
"""Test that non-empty include_ids filters properly"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
# Mock the query results
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
mock_session.scalars.return_value.all.return_value = mock_conversations
mock_session.scalar.return_value = 0
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_stmt.subquery.return_value = MagicMock()
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=["conv1", "conv2"], # Non-empty include_ids
exclude_ids=None,
)
# Verify the where clause was called with id.in_
assert mock_stmt.where.called
def test_pagination_with_empty_exclude_ids(self):
"""Test that empty exclude_ids doesn't filter"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
# Mock the query results
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)]
mock_session.scalars.return_value.all.return_value = mock_conversations
mock_session.scalar.return_value = 0
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_stmt.subquery.return_value = MagicMock()
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=None,
exclude_ids=[], # Empty exclude_ids should not filter
)
# Result should contain the mocked conversations
assert len(result.data) == 5
def test_pagination_with_non_empty_exclude_ids(self):
"""Test that non-empty exclude_ids filters properly"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
# Mock the query results
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
mock_session.scalars.return_value.all.return_value = mock_conversations
mock_session.scalar.return_value = 0
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_stmt.subquery.return_value = MagicMock()
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=None,
exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids
)
# Verify the where clause was called for exclusion
assert mock_stmt.where.called
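The four cases above pin down how the id filters are meant to behave; a condensed sketch of that guard logic, assuming the service builds a SQLAlchemy select (names here are assumptions except where the tests spell them out):

def apply_id_filters(stmt, conversation_model, include_ids, exclude_ids):
    # Assumed semantics distilled from the tests:
    #   include_ids == []          -> caller returns an empty page immediately
    #   include_ids == [ids...]    -> restrict to those ids
    #   exclude_ids in (None, [])  -> no exclusion at all
    #   exclude_ids == [ids...]    -> filter those ids out
    if include_ids is not None:
        if not include_ids:
            return stmt, True  # signal the short-circuit
        stmt = stmt.where(conversation_model.id.in_(include_ids))
    if exclude_ids:
        stmt = stmt.where(~conversation_model.id.in_(exclude_ids))
    return stmt, False

When the short-circuit flag is set, the service can return its pagination wrapper with an empty data list, has_more=False and the requested limit, which is exactly what the first test checks.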

View File

@ -983,6 +983,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" },
]
[[package]]
name = "clickzetta-connector-python"
version = "0.8.102"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "future" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "pandas" },
{ name = "pyarrow" },
{ name = "python-dateutil" },
{ name = "requests" },
{ name = "sqlalchemy" },
{ name = "urllib3" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" },
]
[[package]]
name = "cloudscraper"
version = "1.2.71"
@ -1383,6 +1402,7 @@ vdb = [
{ name = "alibabacloud-tea-openapi" },
{ name = "chromadb" },
{ name = "clickhouse-connect" },
{ name = "clickzetta-connector-python" },
{ name = "couchbase" },
{ name = "elasticsearch" },
{ name = "mo-vector" },
@ -1568,6 +1588,7 @@ vdb = [
{ name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" },
{ name = "chromadb", specifier = "==0.5.20" },
{ name = "clickhouse-connect", specifier = "~=0.7.16" },
{ name = "clickzetta-connector-python", specifier = ">=0.8.102" },
{ name = "couchbase", specifier = "~=4.3.0" },
{ name = "elasticsearch", specifier = "==8.14.0" },
{ name = "mo-vector", specifier = "~=0.1.13" },
@ -2111,7 +2132,7 @@ wheels = [
[[package]]
name = "google-cloud-bigquery"
version = "3.34.0"
version = "3.30.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "google-api-core", extra = ["grpc"] },
@ -2122,9 +2143,9 @@ dependencies = [
{ name = "python-dateutil" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/24/f9/e9da2d56d7028f05c0e2f5edf6ce43c773220c3172666c3dd925791d763d/google_cloud_bigquery-3.34.0.tar.gz", hash = "sha256:5ee1a78ba5c2ccb9f9a8b2bf3ed76b378ea68f49b6cac0544dc55cc97ff7c1ce", size = 489091, upload-time = "2025-05-29T17:18:06.03Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/2f/3dda76b3ec029578838b1fe6396e6b86eb574200352240e23dea49265bb7/google_cloud_bigquery-3.30.0.tar.gz", hash = "sha256:7e27fbafc8ed33cc200fe05af12ecd74d279fe3da6692585a3cef7aee90575b6", size = 474389, upload-time = "2025-02-27T18:49:45.416Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/7e/7115c4f67ca0bc678f25bff1eab56cc37d06eb9a3978940b2ebd0705aa0a/google_cloud_bigquery-3.34.0-py3-none-any.whl", hash = "sha256:de20ded0680f8136d92ff5256270b5920dfe4fae479f5d0f73e90e5df30b1cf7", size = 253555, upload-time = "2025-05-29T17:18:02.904Z" },
{ url = "https://files.pythonhosted.org/packages/0c/6d/856a6ca55c1d9d99129786c929a27dd9d31992628ebbff7f5d333352981f/google_cloud_bigquery-3.30.0-py2.py3-none-any.whl", hash = "sha256:f4d28d846a727f20569c9b2d2f4fa703242daadcb2ec4240905aa485ba461877", size = 247885, upload-time = "2025-02-27T18:49:43.454Z" },
]
[[package]]
@ -3918,11 +3939,11 @@ wheels = [
[[package]]
name = "packaging"
version = "24.2"
version = "23.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" }
sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" },
]
[[package]]
@ -4302,6 +4323,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" },
]
[[package]]
name = "pyarrow"
version = "14.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" },
{ url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" },
{ url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" },
{ url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" },
{ url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" },
{ url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" },
{ url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" },
{ url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" },
{ url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" },
{ url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" },
{ url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" },
{ url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" },
{ url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" },
{ url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" },
]
[[package]]
name = "pyasn1"
version = "0.6.1"

View File

@ -333,6 +333,25 @@ OPENDAL_SCHEME=fs
# Configurations for OpenDAL Local File System.
OPENDAL_FS_ROOT=storage
# ClickZetta Volume Configuration (for storage backend)
# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume
# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters
# Volume type selection (three types available):
# - user: Personal/small team use, simple config, user-level permissions
# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions
# - external: Data lake integration, external storage connection, volume-level + storage-level permissions
CLICKZETTA_VOLUME_TYPE=user
# External Volume name (required only when TYPE=external)
CLICKZETTA_VOLUME_NAME=
# Table Volume table prefix (used only when TYPE=table)
CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_
# Dify file directory prefix (isolates Dify files from other apps; keeping the default is recommended)
CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km
# S3 Configuration
#
S3_ENDPOINT=
@ -416,7 +435,7 @@ SUPABASE_URL=your-server-url
# ------------------------------
# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -655,6 +674,20 @@ TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta`
CLICKZETTA_USERNAME=
CLICKZETTA_PASSWORD=
CLICKZETTA_INSTANCE=
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=quick_start
CLICKZETTA_VCLUSTER=default_ap
CLICKZETTA_SCHEMA=dify
CLICKZETTA_BATCH_SIZE=100
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese
CLICKZETTA_ANALYZER_MODE=smart
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance
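
Taken together, the two Clickzetta blocks in this file cover both roles the service can play: object storage (STORAGE_TYPE=clickzetta-volume plus the CLICKZETTA_VOLUME_* keys above) and vector store (VECTOR_STORE=clickzetta plus the shared CLICKZETTA_* connection keys). A minimal combined configuration might look like the sketch below; the credential values are placeholders, not defaults shipped with Dify:

STORAGE_TYPE=clickzetta-volume
CLICKZETTA_VOLUME_TYPE=user
VECTOR_STORE=clickzetta
CLICKZETTA_USERNAME=your-username
CLICKZETTA_PASSWORD=your-password
CLICKZETTA_INSTANCE=your-instance-id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=quick_start
CLICKZETTA_VCLUSTER=default_ap
CLICKZETTA_SCHEMA=dify
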
# ------------------------------
# Knowledge Configuration
# ------------------------------

View File

@ -93,6 +93,10 @@ x-shared-env: &shared-api-worker-env
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user}
CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-}
CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_}
CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km}
S3_ENDPOINT: ${S3_ENDPOINT:-}
S3_REGION: ${S3_REGION:-us-east-1}
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}
@ -313,6 +317,18 @@ x-shared-env: &shared-api-worker-env
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-}
CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-}
CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-}
CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com}
CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start}
CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap}
CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify}
CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100}
CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true}
CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese}
CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart}
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance}
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
ETL_TYPE: ${ETL_TYPE:-dify}

View File

@ -0,0 +1,97 @@
/**
* Description Validation Test
*
* Tests for the 400-character description validation across App and Dataset
* creation and editing workflows to ensure consistent validation behavior.
*/
describe('Description Validation Logic', () => {
// Simulate backend validation function
const validateDescriptionLength = (description?: string | null) => {
if (description && description.length > 400)
throw new Error('Description cannot exceed 400 characters.')
return description
}
describe('Backend Validation Function', () => {
test('allows description within 400 characters', () => {
const validDescription = 'x'.repeat(400)
expect(() => validateDescriptionLength(validDescription)).not.toThrow()
expect(validateDescriptionLength(validDescription)).toBe(validDescription)
})
test('allows empty description', () => {
expect(() => validateDescriptionLength('')).not.toThrow()
expect(() => validateDescriptionLength(null)).not.toThrow()
expect(() => validateDescriptionLength(undefined)).not.toThrow()
})
test('rejects description exceeding 400 characters', () => {
const invalidDescription = 'x'.repeat(401)
expect(() => validateDescriptionLength(invalidDescription)).toThrow(
'Description cannot exceed 400 characters.',
)
})
})
describe('Backend Validation Consistency', () => {
test('App and Dataset have consistent validation limits', () => {
const maxLength = 400
const validDescription = 'x'.repeat(maxLength)
const invalidDescription = 'x'.repeat(maxLength + 1)
// Both should accept exactly 400 characters
expect(validDescription.length).toBe(400)
expect(() => validateDescriptionLength(validDescription)).not.toThrow()
// Both should reject 401 characters
expect(invalidDescription.length).toBe(401)
expect(() => validateDescriptionLength(invalidDescription)).toThrow()
})
test('validation error messages are consistent', () => {
// Both App and Dataset backend validation surface the same error message
const expectedErrorMessage = 'Description cannot exceed 400 characters.'
const invalidDescription = 'x'.repeat(401)
expect(() => validateDescriptionLength(invalidDescription)).toThrow(expectedErrorMessage)
})
})
describe('Character Length Edge Cases', () => {
const testCases = [
{ length: 0, shouldPass: true, description: 'empty description' },
{ length: 1, shouldPass: true, description: '1 character' },
{ length: 399, shouldPass: true, description: '399 characters' },
{ length: 400, shouldPass: true, description: '400 characters (boundary)' },
{ length: 401, shouldPass: false, description: '401 characters (over limit)' },
{ length: 500, shouldPass: false, description: '500 characters' },
{ length: 1000, shouldPass: false, description: '1000 characters' },
]
testCases.forEach(({ length, shouldPass, description }) => {
test(`handles ${description} correctly`, () => {
const testDescription = length > 0 ? 'x'.repeat(length) : ''
expect(testDescription.length).toBe(length)
if (shouldPass) {
expect(() => validateDescriptionLength(testDescription)).not.toThrow()
expect(validateDescriptionLength(testDescription)).toBe(testDescription)
}
else {
expect(() => validateDescriptionLength(testDescription)).toThrow(
'Description cannot exceed 400 characters.',
)
}
})
})
})
})
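
The suite above exercises the backend limit; a natural client-side counterpart is to run the same bound check before submitting so users get immediate feedback instead of a server error. The helper below is an illustrative sketch, not part of this change, and its names (DESCRIPTION_MAX_LENGTH, validateDescription) are assumptions:

// Hypothetical client-side helper mirroring the 400-character backend limit
const DESCRIPTION_MAX_LENGTH = 400

const validateDescription = (description?: string | null): string | null => {
  if (description && description.length > DESCRIPTION_MAX_LENGTH)
    return `Description cannot exceed ${DESCRIPTION_MAX_LENGTH} characters.`
  return null // null means the value is acceptable
}

// Usage: returns null at or under the limit, the error message otherwise
console.log(validateDescription('x'.repeat(400))) // null
console.log(validateDescription('x'.repeat(401))) // 'Description cannot exceed 400 characters.'

Keeping the limit in a single constant makes it harder for a frontend hint and the backend validation to drift apart.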

View File

@ -4,7 +4,6 @@ import React, { useEffect, useMemo } from 'react'
import { usePathname } from 'next/navigation'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import {
RiEqualizer2Fill,
RiEqualizer2Line,
@ -44,17 +43,12 @@ type IExtraInfoProps = {
}
const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile)
const { t } = useTranslation()
const docLink = useDocLink()
const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0
const relatedAppsTotal = relatedApps?.data?.length || 0
useEffect(() => {
setShowTips(!isMobile)
}, [isMobile, setShowTips])
return <div>
{/* Related apps for desktop */}
<div className={classNames(

View File

@ -1,131 +0,0 @@
'use client'
import { useEffect, useMemo, useState } from 'react'
import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next'
import { RiListUnordered } from '@remixicon/react'
import TemplateEn from './template/template.en.mdx'
import TemplateZh from './template/template.zh.mdx'
import TemplateJa from './template/template.ja.mdx'
import I18n from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import cn from '@/utils/classnames'
type DocProps = {
apiBaseUrl: string
}
const Doc = ({ apiBaseUrl }: DocProps) => {
const { locale } = useContext(I18n)
const { t } = useTranslation()
const [toc, setToc] = useState<Array<{ href: string; text: string }>>([])
const [isTocExpanded, setIsTocExpanded] = useState(false)
const { theme } = useTheme()
// Set initial TOC expanded state based on screen width
useEffect(() => {
const mediaQuery = window.matchMedia('(min-width: 1280px)')
setIsTocExpanded(mediaQuery.matches)
}, [])
// Extract TOC from article content
useEffect(() => {
const extractTOC = () => {
const article = document.querySelector('article')
if (article) {
const headings = article.querySelectorAll('h2')
const tocItems = Array.from(headings).map((heading) => {
const anchor = heading.querySelector('a')
if (anchor) {
return {
href: anchor.getAttribute('href') || '',
text: anchor.textContent || '',
}
}
return null
}).filter((item): item is { href: string; text: string } => item !== null)
setToc(tocItems)
}
}
setTimeout(extractTOC, 0)
}, [locale])
// Handle TOC item click
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string; text: string }) => {
e.preventDefault()
const targetId = item.href.replace('#', '')
const element = document.getElementById(targetId)
if (element) {
const scrollContainer = document.querySelector('.scroll-container')
if (scrollContainer) {
const headerOffset = -40
const elementTop = element.offsetTop - headerOffset
scrollContainer.scrollTo({
top: elementTop,
behavior: 'smooth',
})
}
}
}
const Template = useMemo(() => {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateZh apiBaseUrl={apiBaseUrl} />
case LanguagesSupported[7]:
return <TemplateJa apiBaseUrl={apiBaseUrl} />
default:
return <TemplateEn apiBaseUrl={apiBaseUrl} />
}
}, [apiBaseUrl, locale])
return (
<div className="flex">
<div className={`fixed right-20 top-32 z-10 transition-all ${isTocExpanded ? 'w-64' : 'w-10'}`}>
{isTocExpanded
? (
<nav className="toc max-h-[calc(100vh-150px)] w-full overflow-y-auto rounded-lg bg-components-panel-bg p-4 shadow-md">
<div className="mb-4 flex items-center justify-between">
<h3 className="text-lg font-semibold text-text-primary">{t('appApi.develop.toc')}</h3>
<button
onClick={() => setIsTocExpanded(false)}
className="text-text-tertiary hover:text-text-secondary"
>
</button>
</div>
<ul className="space-y-2">
{toc.map((item, index) => (
<li key={index}>
<a
href={item.href}
className="text-text-secondary transition-colors duration-200 hover:text-text-primary hover:underline"
onClick={e => handleTocClick(e, item)}
>
{item.text}
</a>
</li>
))}
</ul>
</nav>
)
: (
<button
onClick={() => setIsTocExpanded(true)}
className="flex h-10 w-10 items-center justify-center rounded-full bg-components-button-secondary-bg shadow-md transition-colors duration-200 hover:bg-components-button-secondary-bg-hover"
>
<RiListUnordered className="h-6 w-6 text-components-button-secondary-text" />
</button>
)}
</div>
<article className={cn('prose-xl prose mx-1 rounded-t-xl bg-background-default px-4 pt-16 sm:mx-12', theme === Theme.dark && 'prose-invert')}>
{Template}
</article>
</div>
)
}
export default Doc

View File

@ -9,10 +9,10 @@ import { useQuery } from '@tanstack/react-query'
// Components
import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel'
import Datasets from './Datasets'
import DatasetFooter from './DatasetFooter'
import Datasets from './datasets'
import DatasetFooter from './dataset-footer'
import ApiServer from '../../components/develop/ApiServer'
import Doc from './Doc'
import Doc from './doc'
import TabSliderNew from '@/app/components/base/tab-slider-new'
import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter'
@ -86,8 +86,8 @@ const Container = () => {
}, [currentWorkspace, router])
return (
<div ref={containerRef} className='scroll-container relative flex grow flex-col overflow-y-auto bg-background-body'>
<div className='sticky top-0 z-10 flex h-[80px] shrink-0 flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
<div ref={containerRef} className={`scroll-container relative flex grow flex-col overflow-y-auto rounded-t-xl outline-none ${activeTab === 'dataset' ? 'bg-background-body' : 'bg-components-panel-bg'}`}>
<div className={`sticky top-0 z-10 flex shrink-0 flex-wrap items-center justify-between gap-y-2 rounded-t-xl px-6 py-2 ${activeTab === 'api' ? 'border-b border-solid border-b-divider-regular' : ''} ${activeTab === 'dataset' ? 'bg-background-body' : 'bg-components-panel-bg'}`}>
<TabSliderNew
value={activeTab}
onChange={newActiveTab => setActiveTab(newActiveTab)}

View File

@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { RiMoreFill } from '@remixicon/react'
import { mutate } from 'swr'
import cn from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'
import { ToastContext } from '@/app/components/base/toast'
@ -57,6 +58,19 @@ const DatasetCard = ({
const onConfirmDelete = useCallback(async () => {
try {
await deleteDataset(dataset.id)
// Clear SWR cache to prevent stale data in knowledge retrieval nodes
mutate(
(key) => {
if (typeof key === 'string') return key.includes('/datasets')
if (typeof key === 'object' && key !== null)
return key.url === '/datasets' || key.url?.includes('/datasets')
return false
},
undefined,
{ revalidate: true },
)
notify({ type: 'success', message: t('dataset.datasetDeleted') })
if (onSuccess)
onSuccess()
@ -162,24 +176,19 @@ const DatasetCard = ({
</div>
<div
className={cn(
'mb-2 max-h-[72px] grow px-[14px] text-xs leading-normal text-text-tertiary group-hover:line-clamp-2 group-hover:max-h-[36px]',
tags.length ? 'line-clamp-2' : 'line-clamp-4',
'mb-2 line-clamp-2 max-h-[36px] grow px-[14px] text-xs leading-normal text-text-tertiary',
!dataset.embedding_available && 'opacity-50 hover:opacity-100',
)}
title={dataset.description}>
{dataset.description}
</div>
<div className={cn(
'mt-4 h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1',
tags.length ? 'flex' : '!hidden group-hover:!flex',
)}>
<div className='mt-4 flex h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1'>
<div className={cn('flex w-0 grow items-center gap-1', !dataset.embedding_available && 'opacity-50 hover:opacity-100')} onClick={(e) => {
e.stopPropagation()
e.preventDefault()
}}>
<div className={cn(
'mr-[41px] w-full grow group-hover:!mr-0 group-hover:!block',
tags.length ? '!block' : '!hidden',
'mr-[41px] w-full grow group-hover:!mr-0',
)}>
<TagSelector
position='bl'

View File

@ -3,8 +3,8 @@
import { useCallback, useEffect, useRef } from 'react'
import useSWRInfinite from 'swr/infinite'
import { debounce } from 'lodash-es'
import NewDatasetCard from './NewDatasetCard'
import DatasetCard from './DatasetCard'
import NewDatasetCard from './new-dataset-card'
import DatasetCard from './dataset-card'
import type { DataSetListResponse, FetchDatasetsParams } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import { useAppContext } from '@/context/app-context'
@ -36,7 +36,7 @@ const getKey = (
}
type Props = {
containerRef: React.RefObject<HTMLDivElement>
containerRef: React.RefObject<HTMLDivElement | null>
tags: string[]
keywords: string
includeAll: boolean

View File

@ -0,0 +1,203 @@
'use client'
import { useEffect, useMemo, useState } from 'react'
import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next'
import { RiCloseLine, RiListUnordered } from '@remixicon/react'
import TemplateEn from './template/template.en.mdx'
import TemplateZh from './template/template.zh.mdx'
import TemplateJa from './template/template.ja.mdx'
import I18n from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import cn from '@/utils/classnames'
type DocProps = {
apiBaseUrl: string
}
const Doc = ({ apiBaseUrl }: DocProps) => {
const { locale } = useContext(I18n)
const { t } = useTranslation()
const [toc, setToc] = useState<Array<{ href: string; text: string }>>([])
const [isTocExpanded, setIsTocExpanded] = useState(false)
const [activeSection, setActiveSection] = useState<string>('')
const { theme } = useTheme()
// Set initial TOC expanded state based on screen width
useEffect(() => {
const mediaQuery = window.matchMedia('(min-width: 1280px)')
setIsTocExpanded(mediaQuery.matches)
}, [])
// Extract TOC from article content
useEffect(() => {
const extractTOC = () => {
const article = document.querySelector('article')
if (article) {
const headings = article.querySelectorAll('h2')
const tocItems = Array.from(headings).map((heading) => {
const anchor = heading.querySelector('a')
if (anchor) {
return {
href: anchor.getAttribute('href') || '',
text: anchor.textContent || '',
}
}
return null
}).filter((item): item is { href: string; text: string } => item !== null)
setToc(tocItems)
// Set initial active section
if (tocItems.length > 0)
setActiveSection(tocItems[0].href.replace('#', ''))
}
}
setTimeout(extractTOC, 0)
}, [locale])
// Track scroll position for active section highlighting
useEffect(() => {
const handleScroll = () => {
const scrollContainer = document.querySelector('.scroll-container')
if (!scrollContainer || toc.length === 0)
return
// Find active section based on scroll position
let currentSection = ''
toc.forEach((item) => {
const targetId = item.href.replace('#', '')
const element = document.getElementById(targetId)
if (element) {
const rect = element.getBoundingClientRect()
// Consider section active if its top is above the middle of viewport
if (rect.top <= window.innerHeight / 2)
currentSection = targetId
}
})
if (currentSection && currentSection !== activeSection)
setActiveSection(currentSection)
}
const scrollContainer = document.querySelector('.scroll-container')
if (scrollContainer) {
scrollContainer.addEventListener('scroll', handleScroll)
handleScroll() // Initial check
return () => scrollContainer.removeEventListener('scroll', handleScroll)
}
}, [toc, activeSection])
// Handle TOC item click
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string; text: string }) => {
e.preventDefault()
const targetId = item.href.replace('#', '')
const element = document.getElementById(targetId)
if (element) {
const scrollContainer = document.querySelector('.scroll-container')
if (scrollContainer) {
const headerOffset = -40
const elementTop = element.offsetTop - headerOffset
scrollContainer.scrollTo({
top: elementTop,
behavior: 'smooth',
})
}
}
}
const Template = useMemo(() => {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateZh apiBaseUrl={apiBaseUrl} />
case LanguagesSupported[7]:
return <TemplateJa apiBaseUrl={apiBaseUrl} />
default:
return <TemplateEn apiBaseUrl={apiBaseUrl} />
}
}, [apiBaseUrl, locale])
return (
<div className="flex">
<div className={`fixed right-20 top-32 z-10 transition-all duration-150 ease-out ${isTocExpanded ? 'w-[280px]' : 'w-11'}`}>
{isTocExpanded
? (
<nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
<div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
<span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
{t('appApi.develop.toc')}
</span>
<button
onClick={() => setIsTocExpanded(false)}
className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
aria-label="Close"
>
<RiCloseLine className="h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
</button>
</div>
<div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
<div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>
<div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
{toc.length === 0 ? (
<div className="px-2 py-8 text-center text-xs text-text-quaternary">
{t('appApi.develop.noContent')}
</div>
) : (
<ul className="space-y-0.5">
{toc.map((item, index) => {
const isActive = activeSection === item.href.replace('#', '')
return (
<li key={index}>
<a
href={item.href}
onClick={e => handleTocClick(e, item)}
className={cn(
'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
isActive
? 'bg-state-base-hover font-medium text-text-primary'
: 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
)}
>
<span
className={cn(
'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
isActive
? 'scale-100 bg-text-accent'
: 'scale-75 bg-components-panel-border',
)}
/>
<span className="flex-1 truncate">
{item.text}
</span>
</a>
</li>
)
})}
</ul>
)}
</div>
<div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
</nav>
)
: (
<button
onClick={() => setIsTocExpanded(true)}
className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
aria-label="Open table of contents"
>
<RiListUnordered className="h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
</button>
)}
</div>
<article className={cn('prose-xl prose', theme === Theme.dark && 'prose-invert')}>
{Template}
</article>
</div>
)
}
export default Doc

View File

@ -1,6 +1,6 @@
'use client'
import { useTranslation } from 'react-i18next'
import Container from './Container'
import Container from './container'
import useDocumentTitle from '@/hooks/use-document-title'
const AppList = () => {

View File

@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</CodeGroup>
</div>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/document/create-by-text'
@ -163,7 +163,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/document/create-by-file'
@ -294,7 +294,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets'
@ -400,7 +400,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets'
@ -472,7 +472,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -553,7 +553,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -714,7 +714,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -751,7 +751,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
@ -853,7 +853,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
@ -952,7 +952,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{batch}/indexing-status'
@ -1007,7 +1007,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}'
@ -1047,7 +1047,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents'
@ -1122,7 +1122,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}'
@ -1245,7 +1245,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
___
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/status/{action}'
@ -1302,7 +1302,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments'
@ -1388,7 +1388,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments'
@ -1476,7 +1476,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
@ -1546,7 +1546,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
@ -1590,7 +1590,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
@ -1679,7 +1679,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@ -1750,7 +1750,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@ -1827,7 +1827,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@ -1873,7 +1873,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@ -1947,7 +1947,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/upload-file'
@ -1998,7 +1998,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/retrieve'
@ -2177,7 +2177,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@ -2224,7 +2224,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@ -2273,7 +2273,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@ -2306,7 +2306,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/built-in/{action}'
@ -2339,7 +2339,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/metadata'
@ -2378,7 +2378,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@ -2424,7 +2424,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/workspaces/current/models/model-types/text-embedding'
@ -2528,7 +2528,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
Okay, I will translate the Chinese text in your document while keeping all formatting and code content unchanged.
<Heading
@ -2574,7 +2574,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2615,7 +2615,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2662,7 +2662,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
@ -2704,7 +2704,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/binding'
@ -2746,7 +2746,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/unbinding'
@ -2789,7 +2789,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/<uuid:dataset_id>/tags'
@ -2837,7 +2837,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Row>

View File

@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</CodeGroup>
</div>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/document/create-by-text'
@ -163,7 +163,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/document/create-by-file'
@ -294,7 +294,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets'
@ -399,7 +399,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets'
@ -471,7 +471,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -508,7 +508,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
@ -610,7 +610,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
@ -709,7 +709,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{batch}/indexing-status'
@ -764,7 +764,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}'
@ -804,7 +804,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents'
@ -879,7 +879,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}'
@ -1002,7 +1002,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
___
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
@ -1060,7 +1060,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments'
@ -1146,7 +1146,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments'
@ -1234,7 +1234,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
@ -1304,7 +1304,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
method='DELETE'
@ -1347,7 +1347,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
method='POST'
@ -1435,7 +1435,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@ -1506,7 +1506,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@ -1583,7 +1583,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@ -1629,7 +1629,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@ -1703,7 +1703,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/upload-file'
@ -1754,7 +1754,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/retrieve'
@ -1933,7 +1933,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@ -1980,7 +1980,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@ -2029,7 +2029,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@ -2062,7 +2062,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/built-in/{action}'
@ -2095,7 +2095,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/metadata'
@ -2136,7 +2136,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@ -2182,7 +2182,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
method='POST'
@ -2226,7 +2226,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2267,7 +2267,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2314,7 +2314,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
@ -2356,7 +2356,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/binding'
@ -2398,7 +2398,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/unbinding'
@ -2441,7 +2441,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/<uuid:dataset_id>/tags'
@ -2489,7 +2489,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Row>
<Col>

View File

@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</CodeGroup>
</div>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/document/create-by-text'
@ -167,7 +167,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/document/create-by-file'
@ -298,7 +298,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets'
@ -403,7 +403,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets'
@ -475,7 +475,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -556,7 +556,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -721,7 +721,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}'
@ -758,7 +758,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
@ -860,7 +860,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
@ -959,7 +959,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{batch}/indexing-status'
@ -1014,7 +1014,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}'
@ -1054,7 +1054,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents'
@ -1129,7 +1129,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}'
@ -1252,7 +1252,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Col>
</Row>
___
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
@ -1310,7 +1310,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments'
@ -1396,7 +1396,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments'
@ -1484,7 +1484,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
@ -1528,7 +1528,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
@ -1598,7 +1598,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
method='POST'
@ -1687,7 +1687,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@ -1758,7 +1758,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@ -1835,7 +1835,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@ -1881,7 +1881,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Row>
<Col>
@ -1915,7 +1915,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@ -1989,7 +1989,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/upload-file'
@ -2040,7 +2040,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/retrieve'
@ -2219,7 +2219,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@ -2266,7 +2266,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@ -2315,7 +2315,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@ -2348,7 +2348,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/built-in/{action}'
@ -2381,7 +2381,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/metadata'
@ -2422,7 +2422,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@ -2468,7 +2468,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/workspaces/current/models/model-types/text-embedding'
@ -2572,7 +2572,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2617,7 +2617,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2658,7 +2658,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@ -2705,7 +2705,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
@ -2747,7 +2747,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/binding'
@ -2789,7 +2789,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/unbinding'
@ -2832,7 +2832,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/<uuid:dataset_id>/tags'
@ -2880,7 +2880,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Row>
<Col>

View File

@ -87,7 +87,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
<Avatar {...props} />
<div
onClick={() => { setIsShowAvatarPicker(true) }}
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black bg-opacity-50 opacity-0 transition-opacity group-hover:opacity-100"
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100"
>
<span className="text-xs text-white">
<RiPencilLine />

View File

@ -12,7 +12,6 @@ import {
RiFileUploadLine,
} from '@remixicon/react'
import AppIcon from '../base/app-icon'
import cn from '@/utils/classnames'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import { useAppContext } from '@/context/app-context'
@ -31,6 +30,7 @@ import Divider from '../base/divider'
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
import cn from '@/utils/classnames'
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {
ssr: false,
@ -256,32 +256,40 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
}}
className='block w-full'
>
<div className={cn('flex rounded-lg', expand ? 'flex-col gap-2 p-2 pb-2.5' : 'items-start justify-center gap-1 p-1', open && 'bg-state-base-hover', isCurrentWorkspaceEditor && 'cursor-pointer hover:bg-state-base-hover')}>
<div className={`flex items-center self-stretch ${expand ? 'justify-between' : 'flex-col gap-1'}`}>
<AppIcon
size={expand ? 'large' : 'small'}
iconType={appDetail.icon_type}
icon={appDetail.icon}
background={appDetail.icon_background}
imageUrl={appDetail.icon_url}
/>
<div className='flex items-center justify-center rounded-md p-0.5'>
<div className='flex h-5 w-5 items-center justify-center'>
<div className='flex flex-col gap-2 rounded-lg p-1 hover:bg-state-base-hover'>
<div className='flex items-center gap-1'>
<div className={cn(!expand && 'ml-1')}>
<AppIcon
size={expand ? 'large' : 'small'}
iconType={appDetail.icon_type}
icon={appDetail.icon}
background={appDetail.icon_background}
imageUrl={appDetail.icon_url}
/>
</div>
{expand && (
<div className='ml-auto flex items-center justify-center rounded-md p-0.5'>
<div className='flex h-5 w-5 items-center justify-center'>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</div>
</div>
)}
</div>
{!expand && (
<div className='flex items-center justify-center'>
<div className='flex h-5 w-5 items-center justify-center rounded-md p-0.5'>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</div>
</div>
</div>
<div className={cn(
'flex flex-col items-start gap-1 transition-all duration-200 ease-in-out',
expand
? 'w-auto opacity-100'
: 'pointer-events-none w-0 overflow-hidden opacity-0',
)}>
<div className='flex w-full'>
<div className='system-md-semibold truncate whitespace-nowrap text-text-secondary'>{appDetail.name}</div>
)}
{expand && (
<div className='flex flex-col items-start gap-1'>
<div className='flex w-full'>
<div className='system-md-semibold truncate whitespace-nowrap text-text-secondary'>{appDetail.name}</div>
</div>
<div className='system-2xs-medium-uppercase whitespace-nowrap text-text-tertiary'>{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}</div>
</div>
<div className='system-2xs-medium-uppercase whitespace-nowrap text-text-tertiary'>{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}</div>
</div>
)}
</div>
</button>
)}

View File

@ -32,7 +32,7 @@ const AccessControlDialog = ({
leaveFrom="opacity-100"
leaveTo="opacity-0"
>
<div className="fixed inset-0 bg-background-overlay bg-opacity-25" />
<div className="bg-background-overlay/25 fixed inset-0" />
</Transition.Child>
<div className="fixed inset-0 flex items-center justify-center">

View File

@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() {
setSelectedGroupsForBreadcrumb([])
}, [setSelectedGroupsForBreadcrumb])
return <div className='flex h-7 items-center gap-x-0.5 px-2 py-0.5'>
<span className={classNames('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'text-text-accent cursor-pointer')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
<span className={classNames('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
{selectedGroupsForBreadcrumb.map((group, index) => {
return <div key={index} className='system-xs-regular flex items-center gap-x-0.5 text-text-tertiary'>
<span>/</span>
@ -198,7 +198,7 @@ type BaseItemProps = {
children: React.ReactNode
}
function BaseItem({ children, className }: BaseItemProps) {
return <div className={classNames('p-1 pl-2 flex items-center space-x-2 hover:rounded-lg hover:bg-state-base-hover cursor-pointer', className)}>
return <div className={classNames('flex cursor-pointer items-center space-x-2 p-1 pl-2 hover:rounded-lg hover:bg-state-base-hover', className)}>
{children}
</div>
}

View File

@ -4,7 +4,6 @@ import React, { useRef, useState } from 'react'
import { useGetState, useInfiniteScroll } from 'ahooks'
import { useTranslation } from 'react-i18next'
import Link from 'next/link'
import produce from 'immer'
import TypeIcon from '../type-icon'
import Modal from '@/app/components/base/modal'
import type { DataSet } from '@/models/datasets'
@ -29,9 +28,10 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
onSelect,
}) => {
const { t } = useTranslation()
const [selected, setSelected] = React.useState<DataSet[]>(selectedIds.map(id => ({ id }) as any))
const [selected, setSelected] = React.useState<DataSet[]>([])
const [loaded, setLoaded] = React.useState(false)
const [datasets, setDataSets] = React.useState<DataSet[] | null>(null)
const [hasInitialized, setHasInitialized] = React.useState(false)
const hasNoData = !datasets || datasets?.length === 0
const canSelectMulti = true
@ -49,19 +49,17 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
const newList = [...(datasets || []), ...data.filter(item => item.indexing_technique || item.provider === 'external')]
setDataSets(newList)
setLoaded(true)
if (!selected.find(item => !item.name))
return { list: [] }
const newSelected = produce(selected, (draft) => {
selected.forEach((item, index) => {
if (!item.name) { // not fetched database
const newItem = newList.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setSelected(newSelected)
// Initialize selected datasets based on selectedIds and available datasets
if (!hasInitialized) {
if (selectedIds.length > 0) {
const validSelectedDatasets = selectedIds
.map(id => newList.find(item => item.id === id))
.filter(Boolean) as DataSet[]
setSelected(validSelectedDatasets)
}
setHasInitialized(true)
}
}
return { list: [] }
},

View File

@ -55,8 +55,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding)
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const { t } = useTranslation()
const docLink = useDocLink()

View File

@ -40,13 +40,13 @@ type CategoryItemProps = {
}
function CategoryItem({ category, active, onClick }: CategoryItemProps) {
return <li
className={classNames('p-1 pl-3 h-8 rounded-lg flex items-center gap-2 group cursor-pointer hover:bg-state-base-hover [&.active]:bg-state-base-active', active && 'active')}
className={classNames('group flex h-8 cursor-pointer items-center gap-2 rounded-lg p-1 pl-3 hover:bg-state-base-hover [&.active]:bg-state-base-active', active && 'active')}
onClick={() => { onClick?.(category) }}>
{category === AppCategories.RECOMMENDED && <div className='inline-flex h-5 w-5 items-center justify-center rounded-md'>
<RiThumbUpLine className='h-4 w-4 text-components-menu-item-text group-[.active]:text-components-menu-item-text-active' />
</div>}
<AppCategoryLabel category={category}
className={classNames('system-sm-medium text-components-menu-item-text group-[.active]:text-components-menu-item-text-active group-hover:text-components-menu-item-text-hover', active && 'system-sm-semibold')} />
className={classNames('system-sm-medium text-components-menu-item-text group-hover:text-components-menu-item-text-hover group-[.active]:text-components-menu-item-text-active', active && 'system-sm-semibold')} />
</li >
}

View File

@ -82,8 +82,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
getRedirection(isCurrentWorkspaceEditor, app, push)
}
catch {
notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
catch (e: any) {
notify({
type: 'error',
message: e.message || t('app.newApp.appCreateFailed'),
})
}
isCreatingRef.current = false
}, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor])

View File

@ -106,8 +106,8 @@ const Uploader: FC<Props> = ({
<div className='flex w-full items-center justify-center space-x-2'>
<RiUploadCloud2Line className='h-6 w-6 text-text-tertiary' />
<div className='text-text-tertiary'>
{t('datasetCreation.stepOne.uploader.button')}
<span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('datasetDocuments.list.batchModal.browse')}</span>
{t('app.dslUploader.button')}
<span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('app.dslUploader.browse')}</span>
</div>
</div>
{dragging && <div ref={dragRef} className='absolute left-0 top-0 h-full w-full' />}

View File

@ -117,8 +117,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
if (onRefresh)
onRefresh()
}
catch {
notify({ type: 'error', message: t('app.editFailed') })
catch (e: any) {
notify({
type: 'error',
message: e.message || t('app.editFailed'),
})
}
}, [app.id, notify, onRefresh, t])
@ -364,26 +367,20 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
</div>
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>
<div
className={cn(tags.length ? 'line-clamp-2' : 'line-clamp-4', 'group-hover:line-clamp-2')}
className='line-clamp-2'
title={app.description}
>
{app.description}
</div>
</div>
<div className={cn(
'absolute bottom-1 left-0 right-0 h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1',
tags.length ? 'flex' : '!hidden group-hover:!flex',
)}>
<div className='absolute bottom-1 left-0 right-0 flex h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1'>
{isCurrentWorkspaceEditor && (
<>
<div className={cn('flex w-0 grow items-center gap-1')} onClick={(e) => {
e.stopPropagation()
e.preventDefault()
}}>
<div className={cn(
'mr-[41px] w-full grow group-hover:!mr-0 group-hover:!block',
tags.length ? '!block' : '!hidden',
)}>
<div className='mr-[41px] w-full grow group-hover:!mr-0'>
<TagSelector
position='bl'
type='app'
@ -395,7 +392,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
/>
</div>
</div>
<div className='mx-1 !hidden h-[14px] w-[1px] shrink-0 group-hover:!flex' />
<div className='mx-1 !hidden h-[14px] w-[1px] shrink-0 bg-divider-regular group-hover:!flex' />
<div className='!hidden shrink-0 group-hover:!flex'>
<CustomPopover
htmlContent={<Operations />}

View File

@ -1,6 +1,6 @@
import React, { useState } from 'react'
import React from 'react'
import Link from 'next/link'
import { RiCloseLine, RiDiscordFill, RiGithubFill } from '@remixicon/react'
import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
type CustomLinkProps = {
@ -26,24 +26,9 @@ const CustomLink = React.memo(({
const Footer = () => {
const { t } = useTranslation()
const [isVisible, setIsVisible] = useState(true)
const handleClose = () => {
setIsVisible(false)
}
if (!isVisible)
return null
return (
<footer className='relative shrink-0 grow-0 px-12 py-2'>
<button
onClick={handleClose}
className='absolute right-2 top-2 flex h-6 w-6 cursor-pointer items-center justify-center rounded-full transition-colors duration-200 ease-in-out hover:bg-components-main-nav-nav-button-bg-active'
aria-label="Close footer"
>
<RiCloseLine className='h-4 w-4 text-text-tertiary hover:text-text-secondary' />
</button>
<h3 className='text-gradient text-xl font-semibold leading-tight'>{t('app.join')}</h3>
<p className='system-sm-regular mt-1 text-text-tertiary'>{t('app.communityIntro')}</p>
<div className='mt-3 flex items-center gap-2'>

View File

@ -1,14 +1,11 @@
'use client'
import { useEducationInit } from '@/app/education-apply/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import List from './list'
import Footer from './footer'
import useDocumentTitle from '@/hooks/use-document-title'
import { useTranslation } from 'react-i18next'
const Apps = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
useDocumentTitle(t('common.menus.apps'))
useEducationInit()
@ -16,9 +13,6 @@ const Apps = () => {
return (
<div className='relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body'>
<List />
{!systemFeatures.branding.enabled && (
<Footer />
)}
</div >
)
}

View File

@ -32,6 +32,8 @@ import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
import dynamic from 'next/dynamic'
import Empty from './empty'
import Footer from './footer'
import { useGlobalPublicStore } from '@/context/global-public-context'
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
ssr: false,
@ -66,6 +68,7 @@ const getKey = (
const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext()
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
@ -229,6 +232,9 @@ const List = () => {
<span className="system-xs-regular">{t('app.newApp.dropDSLToCreateApp')}</span>
</div>
)}
{!systemFeatures.branding.enabled && (
<Footer />
)}
<CheckModal />
<div ref={anchorRef} className='h-0'> </div>
{showTagManagementModal && (

Some files were not shown because too many files have changed in this diff.