mirror of https://github.com/langgenius/dify.git

Merge branch 'main' into feat/model-auth

commit 1a642084b5

File diff suppressed because it is too large
@@ -4,6 +4,23 @@ title: "[Chore/Refactor] "
labels:
  - refactor
body:
  - type: checkboxes
    attributes:
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      options:
        - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
          required: true
        - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
          required: true
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
        - label: I confirm that I am using English to submit this report, otherwise it will be closed.
          required: true
        - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :)
          required: true
        - label: "Please do not modify this template :) and fill in all the required fields."
          required: true
  - type: textarea
    id: description
    attributes:
@@ -1,13 +1,18 @@
name: Check i18n Files and Create PR

on:
  pull_request:
    types: [closed]
  push:
    branches: [main]
    paths:
      - 'web/i18n/en-US/*.ts'

permissions:
  contents: write
  pull-requests: write

jobs:
  check-and-update:
    if: github.event.pull_request.merged == true
    if: github.repository == 'langgenius/dify'
    runs-on: ubuntu-latest
    defaults:
      run:
@@ -15,8 +20,8 @@ jobs:
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 2 # last 2 commits
          persist-credentials: false
          fetch-depth: 2
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Check for file changes in i18n/en-US
        id: check_files
@@ -27,6 +32,13 @@ jobs:
          echo "Changed files: $changed_files"
          if [ -n "$changed_files" ]; then
            echo "FILES_CHANGED=true" >> $GITHUB_ENV
            file_args=""
            for file in $changed_files; do
              filename=$(basename "$file" .ts)
              file_args="$file_args --file=$filename"
            done
            echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
            echo "File arguments: $file_args"
          else
            echo "FILES_CHANGED=false" >> $GITHUB_ENV
          fi
@@ -49,14 +61,15 @@ jobs:
        if: env.FILES_CHANGED == 'true'
        run: pnpm install --frozen-lockfile

      - name: Run npm script
      - name: Generate i18n translations
        if: env.FILES_CHANGED == 'true'
        run: pnpm run auto-gen-i18n
        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}

      - name: Create Pull Request
        if: env.FILES_CHANGED == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          commit-message: Update i18n files based on en-US changes
          title: 'chore: translate i18n files'
          body: This PR was automatically created to update i18n files based on changes in en-US locale.
@@ -215,3 +215,4 @@ mise.toml
# AI Assistant
.roo/
api/.env.backup
/clickzetta
@@ -19,7 +19,7 @@ RUN apt-get update \

# Install Python dependencies
COPY pyproject.toml uv.lock ./
RUN uv sync --locked
RUN uv sync --locked --no-dev

# production stage
FROM base AS production
@@ -9,7 +9,7 @@ import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from sqlalchemy.exc import SQLAlchemyError

from configs import dify_config
from constants.languages import languages

@@ -181,8 +181,8 @@ def migrate_annotation_vector_database():
            )
            if not apps:
                break
        except NotFound:
            break
        except SQLAlchemyError:
            raise

        page += 1
        for app in apps:

@@ -308,8 +308,8 @@ def migrate_knowledge_vector_database():
            )

            datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
        except NotFound:
            break
        except SQLAlchemyError:
            raise

        page += 1
        for dataset in datasets:

@@ -561,8 +561,8 @@ def old_metadata_migration():
                .order_by(DatasetDocument.created_at.desc())
            )
            documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
        except NotFound:
            break
        except SQLAlchemyError:
            raise
        if not documents:
            break
        for document in documents:
@@ -330,17 +330,17 @@ class HttpConfig(BaseSettings):
    def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
        PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests")
    ] = 10
    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
        ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
    )

    HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
        PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests")
    ] = 60
    HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
        ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
    )

    HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
        PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests")
    ] = 20
    HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
        ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
    )

    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
        description="Maximum allowed size in bytes for binary data in HTTP requests",
@@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig

@@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig

@@ -52,6 +54,7 @@ class StorageConfig(BaseSettings):
        "aliyun-oss",
        "azure-blob",
        "baidu-obs",
        "clickzetta-volume",
        "google-storage",
        "huawei-obs",
        "oci-storage",

@@ -61,8 +64,9 @@ class StorageConfig(BaseSettings):
        "local",
    ] = Field(
        description="Type of storage to use."
        " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
        "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
        " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
        "'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
        "'volcengine-tos', 'supabase'. Default is 'opendal'.",
        default="opendal",
    )

@@ -303,6 +307,7 @@ class MiddlewareConfig(
    AliyunOSSStorageConfig,
    AzureBlobStorageConfig,
    BaiduOBSStorageConfig,
    ClickZettaVolumeStorageConfig,
    GoogleCloudStorageConfig,
    HuaweiCloudOBSStorageConfig,
    OCIStorageConfig,

@@ -315,6 +320,7 @@ class MiddlewareConfig(
    VectorStoreConfig,
    AnalyticdbConfig,
    ChromaConfig,
    ClickzettaConfig,
    HuaweiCloudConfig,
    MilvusConfig,
    MyScaleConfig,
@@ -0,0 +1,65 @@
"""ClickZetta Volume Storage Configuration"""

from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class ClickZettaVolumeStorageConfig(BaseSettings):
    """Configuration for ClickZetta Volume storage."""

    CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
        description="Username for ClickZetta Volume authentication",
        default=None,
    )

    CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
        description="Password for ClickZetta Volume authentication",
        default=None,
    )

    CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
        description="ClickZetta instance identifier",
        default=None,
    )

    CLICKZETTA_VOLUME_SERVICE: str = Field(
        description="ClickZetta service endpoint",
        default="api.clickzetta.com",
    )

    CLICKZETTA_VOLUME_WORKSPACE: str = Field(
        description="ClickZetta workspace name",
        default="quick_start",
    )

    CLICKZETTA_VOLUME_VCLUSTER: str = Field(
        description="ClickZetta virtual cluster name",
        default="default_ap",
    )

    CLICKZETTA_VOLUME_SCHEMA: str = Field(
        description="ClickZetta schema name",
        default="dify",
    )

    CLICKZETTA_VOLUME_TYPE: str = Field(
        description="ClickZetta volume type (table|user|external)",
        default="user",
    )

    CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
        description="ClickZetta volume name for external volumes",
        default=None,
    )

    CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
        description="Prefix for ClickZetta volume table names",
        default="dataset_",
    )

    CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
        description="Directory prefix for User Volume to organize Dify files",
        default="dify_km",
    )
@@ -0,0 +1,69 @@
from typing import Optional

from pydantic import BaseModel, Field


class ClickzettaConfig(BaseModel):
    """
    Clickzetta Lakehouse vector database configuration
    """

    CLICKZETTA_USERNAME: Optional[str] = Field(
        description="Username for authenticating with Clickzetta Lakehouse",
        default=None,
    )

    CLICKZETTA_PASSWORD: Optional[str] = Field(
        description="Password for authenticating with Clickzetta Lakehouse",
        default=None,
    )

    CLICKZETTA_INSTANCE: Optional[str] = Field(
        description="Clickzetta Lakehouse instance ID",
        default=None,
    )

    CLICKZETTA_SERVICE: Optional[str] = Field(
        description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
        default="api.clickzetta.com",
    )

    CLICKZETTA_WORKSPACE: Optional[str] = Field(
        description="Clickzetta workspace name",
        default="default",
    )

    CLICKZETTA_VCLUSTER: Optional[str] = Field(
        description="Clickzetta virtual cluster name",
        default="default_ap",
    )

    CLICKZETTA_SCHEMA: Optional[str] = Field(
        description="Database schema name in Clickzetta",
        default="public",
    )

    CLICKZETTA_BATCH_SIZE: Optional[int] = Field(
        description="Batch size for bulk insert operations",
        default=100,
    )

    CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field(
        description="Enable inverted index for full-text search capabilities",
        default=True,
    )

    CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field(
        description="Analyzer type for full-text search: keyword, english, chinese, unicode",
        default="chinese",
    )

    CLICKZETTA_ANALYZER_MODE: Optional[str] = Field(
        description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
        default="smart",
    )

    CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
        description="Distance function for vector similarity: l2_distance or cosine_distance",
        default="cosine_distance",
    )
@@ -225,14 +225,15 @@ class AnnotationBatchImportApi(Resource):
            raise Forbidden()

        app_id = str(app_id)
        # get file from request
        file = request.files["file"]
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        # get file from request
        file = request.files["file"]
        # check file type
        if not file.filename or not file.filename.lower().endswith(".csv"):
            raise ValueError("Invalid file type. Only CSV files are allowed")
@@ -28,6 +28,12 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]


def _validate_description_length(description):
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class AppListApi(Resource):
    @setup_required
    @login_required

@@ -94,7 +100,7 @@ class AppListApi(Resource):
        """Create app"""
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("description", type=_validate_description_length, location="json")
        parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")

@@ -146,7 +152,7 @@ class AppApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("description", type=_validate_description_length, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")

@@ -189,7 +195,7 @@ class AppCopyApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("description", type=_validate_description_length, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
@@ -41,7 +41,7 @@ def _validate_name(name):


def _validate_description_length(description):
    if len(description) > 400:
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


@@ -113,7 +113,7 @@ class DatasetListApi(Resource):
        )
        parser.add_argument(
            "description",
            type=str,
            type=_validate_description_length,
            nullable=True,
            required=False,
            default="",

@@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource):
                | VectorType.HUAWEI_CLOUD
                | VectorType.TENCENT
                | VectorType.MATRIXONE
                | VectorType.CLICKZETTA
            ):
                return {
                    "retrieval_method": [

@@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                | VectorType.TENCENT
                | VectorType.HUAWEI_CLOUD
                | VectorType.MATRIXONE
                | VectorType.CLICKZETTA
            ):
                return {
                    "retrieval_method": [
@@ -49,7 +49,6 @@ class FileApi(Resource):
    @marshal_with(file_fields)
    @cloud_edition_billing_resource_check("documents")
    def post(self):
        file = request.files["file"]
        source_str = request.form.get("source")
        source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None

@@ -58,6 +57,7 @@

        if len(request.files) > 1:
            raise TooManyFilesError()
        file = request.files["file"]

        if not file.filename:
            raise FilenameNotExistsError
@@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource):
    @account_initialization_required
    @cloud_edition_billing_resource_check("workspace_custom")
    def post(self):
        # get file from request
        file = request.files["file"]

        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

@@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource):
        if len(request.files) > 1:
            raise TooManyFilesError()

        # get file from request
        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError
@@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)

from . import index
from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models
@@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415


class FileNotFoundError(BaseHTTPException):
    error_code = "file_not_found"
    description = "The requested file was not found."
    code = 404


class FileAccessDeniedError(BaseHTTPException):
    error_code = "file_access_denied"
    description = "Access to the requested file is denied."
    code = 403
@@ -20,18 +20,17 @@ class FileApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
    @marshal_with(file_fields)
    def post(self, app_model: App, end_user: EndUser):
        file = request.files["file"]

        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

        if not file.mimetype:
            raise UnsupportedFileTypeError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files["file"]
        if not file.mimetype:
            raise UnsupportedFileTypeError()

        if not file.filename:
            raise FilenameNotExistsError
@@ -0,0 +1,186 @@
import logging
from urllib.parse import quote

from flask import Response
from flask_restful import Resource, reqparse

from controllers.service_api import api
from controllers.service_api.app.error import (
    FileAccessDeniedError,
    FileNotFoundError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, EndUser, Message, MessageFile, UploadFile

logger = logging.getLogger(__name__)


class FilePreviewApi(Resource):
    """
    Service API File Preview endpoint

    Provides secure file preview/download functionality for external API users.
    Files can only be accessed if they belong to messages within the requesting app's context.
    """

    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    def get(self, app_model: App, end_user: EndUser, file_id: str):
        """
        Preview/Download a file that was uploaded via Service API

        Args:
            app_model: The authenticated app model
            end_user: The authenticated end user (optional)
            file_id: UUID of the file to preview

        Query Parameters:
            user: Optional user identifier
            as_attachment: Boolean, whether to download as attachment (default: false)

        Returns:
            Stream response with file content

        Raises:
            FileNotFoundError: File does not exist
            FileAccessDeniedError: File access denied (not owned by app)
        """
        file_id = str(file_id)

        # Parse query parameters
        parser = reqparse.RequestParser()
        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
        args = parser.parse_args()

        # Validate file ownership and get file objects
        message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)

        # Get file content generator
        try:
            generator = storage.load(upload_file.key, stream=True)
        except Exception as e:
            raise FileNotFoundError(f"Failed to load file content: {str(e)}")

        # Build response with appropriate headers
        response = self._build_file_response(generator, upload_file, args["as_attachment"])

        return response

    def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]:
        """
        Validate that the file belongs to a message within the requesting app's context

        Security validations performed:
        1. File exists in MessageFile table (was used in a conversation)
        2. Message belongs to the requesting app
        3. UploadFile record exists and is accessible
        4. File tenant matches app tenant (additional security layer)

        Args:
            file_id: UUID of the file to validate
            app_id: UUID of the requesting app

        Returns:
            Tuple of (MessageFile, UploadFile) if validation passes

        Raises:
            FileNotFoundError: File or related records not found
            FileAccessDeniedError: File does not belong to the app's context
        """
        try:
            # Input validation
            if not file_id or not app_id:
                raise FileAccessDeniedError("Invalid file or app identifier")

            # First, find the MessageFile that references this upload file
            message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()

            if not message_file:
                raise FileNotFoundError("File not found in message context")

            # Get the message and verify it belongs to the requesting app
            message = (
                db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
            )

            if not message:
                raise FileAccessDeniedError("File access denied: not owned by requesting app")

            # Get the actual upload file record
            upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()

            if not upload_file:
                raise FileNotFoundError("Upload file record not found")

            # Additional security: verify tenant isolation
            app = db.session.query(App).where(App.id == app_id).first()
            if app and upload_file.tenant_id != app.tenant_id:
                raise FileAccessDeniedError("File access denied: tenant mismatch")

            return message_file, upload_file

        except (FileNotFoundError, FileAccessDeniedError):
            # Re-raise our custom exceptions
            raise
        except Exception as e:
            # Log unexpected errors for debugging
            logger.exception(
                "Unexpected error during file ownership validation",
                extra={"file_id": file_id, "app_id": app_id, "error": str(e)},
            )
            raise FileAccessDeniedError("File access validation failed")

    def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response:
        """
        Build Flask Response object with appropriate headers for file streaming

        Args:
            generator: File content generator from storage
            upload_file: UploadFile database record
            as_attachment: Whether to set Content-Disposition as attachment

        Returns:
            Flask Response object with streaming file content
        """
        response = Response(
            generator,
            mimetype=upload_file.mime_type,
            direct_passthrough=True,
            headers={},
        )

        # Add Content-Length if known
        if upload_file.size and upload_file.size > 0:
            response.headers["Content-Length"] = str(upload_file.size)

        # Add Accept-Ranges header for audio/video files to support seeking
        if upload_file.mime_type in [
            "audio/mpeg",
            "audio/wav",
            "audio/mp4",
            "audio/ogg",
            "audio/flac",
            "audio/aac",
            "video/mp4",
            "video/webm",
            "video/quicktime",
            "audio/x-m4a",
        ]:
            response.headers["Accept-Ranges"] = "bytes"

        # Set Content-Disposition for downloads
        if as_attachment and upload_file.name:
            encoded_filename = quote(upload_file.name)
            response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
            # Override content-type for downloads to force download
            response.headers["Content-Type"] = "application/octet-stream"

        # Add caching headers for performance
        response.headers["Cache-Control"] = "public, max-age=3600"  # Cache for 1 hour

        return response


# Register the API endpoint
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
@@ -29,7 +29,7 @@ def _validate_name(name):


def _validate_description_length(description):
    if len(description) > 400:
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


@@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource):
        )
        parser.add_argument(
            "description",
            type=str,
            type=_validate_description_length,
            nullable=True,
            required=False,
            default="",
@@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource):
                args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
            )

        # save file info
        file = request.files["file"]
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

@@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource):
        if len(request.files) > 1:
            raise TooManyFilesError()

        # save file info
        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError
@@ -1,5 +1,6 @@
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized

from controllers.common import fields
from controllers.web import api

@@ -75,14 +76,14 @@ class AppWebAuthPermission(Resource):
        try:
            auth_header = request.headers.get("Authorization")
            if auth_header is None:
                raise
                raise Unauthorized("Authorization header is missing.")
            if " " not in auth_header:
                raise
                raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

            auth_scheme, tk = auth_header.split(None, 1)
            auth_scheme = auth_scheme.lower()
            if auth_scheme != "bearer":
                raise
                raise Unauthorized("Authorization scheme must be 'Bearer'")

            decoded = PassportService().verify(tk)
            user_id = decoded.get("user_id", "visitor")
@@ -12,18 +12,17 @@ from services.file_service import FileService
class FileApi(WebApiResource):
    @marshal_with(file_fields)
    def post(self, app_model, end_user):
        file = request.files["file"]
        source = request.form.get("source")

        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError

        source = request.form.get("source")
        if source not in ("datasets", None):
            source = None
@@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        ):
            return

        # Init conversation variables
        stmt = select(ConversationVariable).where(
            ConversationVariable.app_id == self.conversation.app_id,
            ConversationVariable.conversation_id == self.conversation.id,
        )
        with Session(db.engine) as session:
            db_conversation_variables = session.scalars(stmt).all()
            if not db_conversation_variables:
                # Create conversation variables if they don't exist.
                db_conversation_variables = [
                    ConversationVariable.from_variable(
                        app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
                    )
                    for variable in self._workflow.conversation_variables
                ]
                session.add_all(db_conversation_variables)
            # Convert database entities to variables.
            conversation_variables = [item.to_variable() for item in db_conversation_variables]

            session.commit()
        # Initialize conversation variables
        conversation_variables = self._initialize_conversation_variables()

        # Create a variable pool.
        system_inputs = SystemVariable(

@@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            message_id=message_id,
            trace_manager=app_generate_entity.trace_manager,
        )

    def _initialize_conversation_variables(self) -> list[VariableUnion]:
        """
        Initialize conversation variables for the current conversation.

        This method:
        1. Loads existing variables from the database
        2. Creates new variables if none exist
        3. Syncs missing variables from the workflow definition

        :return: List of conversation variables ready for use
        """
        with Session(db.engine) as session:
            existing_variables = self._load_existing_conversation_variables(session)

            if not existing_variables:
                # First time initialization - create all variables
                existing_variables = self._create_all_conversation_variables(session)
            else:
                # Check and add any missing variables from the workflow
                existing_variables = self._sync_missing_conversation_variables(session, existing_variables)

            # Convert to Variable objects for use in the workflow
            conversation_variables = [var.to_variable() for var in existing_variables]

            session.commit()
        return cast(list[VariableUnion], conversation_variables)

    def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
        """
        Load existing conversation variables from the database.

        :param session: Database session
        :return: List of existing conversation variables
        """
        stmt = select(ConversationVariable).where(
            ConversationVariable.app_id == self.conversation.app_id,
            ConversationVariable.conversation_id == self.conversation.id,
        )
        return list(session.scalars(stmt).all())

    def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]:
        """
        Create all conversation variables for a new conversation.

        :param session: Database session
        :return: List of created conversation variables
        """
        new_variables = [
            ConversationVariable.from_variable(
                app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
            )
            for variable in self._workflow.conversation_variables
        ]

        if new_variables:
            session.add_all(new_variables)

        return new_variables

    def _sync_missing_conversation_variables(
        self, session: Session, existing_variables: list[ConversationVariable]
    ) -> list[ConversationVariable]:
        """
        Sync missing conversation variables from the workflow definition.

        This handles the case where new variables are added to a workflow
        after conversations have already been created.

        :param session: Database session
        :param existing_variables: List of existing conversation variables
        :return: Updated list including any newly created variables
        """
        # Get IDs of existing and workflow variables
        existing_ids = {var.id for var in existing_variables}
        workflow_variables = {var.id: var for var in self._workflow.conversation_variables}

        # Find missing variable IDs
        missing_ids = set(workflow_variables.keys()) - existing_ids

        if not missing_ids:
            return existing_variables

        # Create missing variables with their default values
        new_variables = [
            ConversationVariable.from_variable(
                app_id=self.conversation.app_id,
                conversation_id=self.conversation.id,
                variable=workflow_variables[var_id],
            )
            for var_id in missing_ids
        ]

        session.add_all(new_variables)

        # Return combined list
        return existing_variables + new_variables
@@ -23,6 +23,7 @@ from core.app.entities.task_entities import (
    MessageFileStreamResponse,
    MessageReplaceStreamResponse,
    MessageStreamResponse,
    StreamEvent,
    WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator

@@ -180,11 +181,15 @@ class MessageCycleManager:
        :param message_id: message id
        :return:
        """
        message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
        event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

        return MessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer,
            from_variable_selector=from_variable_selector,
            event=event_type,
        )

    def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
@@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel):
                continue

            status = ModelStatus.ACTIVE
            if m.model in model_setting_map:
            if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                model_setting = model_setting_map[m.model_type][m.model]
                if model_setting.enabled is False:
                    status = ModelStatus.DISABLED
@@ -121,9 +121,8 @@ class TokenBufferMemory:
        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        if curr_message_tokens > max_token_limit:
            pruned_memory = []
            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
                pruned_memory.append(prompt_messages.pop(0))
                prompt_messages.pop(0)
                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        return prompt_messages
@@ -0,0 +1,190 @@
# Clickzetta Vector Database Integration

This module provides integration with Clickzetta Lakehouse as a vector database for Dify.

## Features

- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
- **Vector Search**: Efficient similarity search using HNSW algorithm
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
- **Hybrid Search**: Combine vector similarity and full-text search for better results
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments

## Configuration

### Required Environment Variables

All seven configuration parameters are required:

```bash
# Authentication
CLICKZETTA_USERNAME=your_username
CLICKZETTA_PASSWORD=your_password

# Instance configuration
CLICKZETTA_INSTANCE=your_instance_id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=your_workspace
CLICKZETTA_VCLUSTER=your_vcluster
CLICKZETTA_SCHEMA=your_schema
```

### Optional Configuration

```bash
# Batch processing
CLICKZETTA_BATCH_SIZE=100

# Full-text search configuration
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese  # Options: keyword, english, chinese, unicode
CLICKZETTA_ANALYZER_MODE=smart    # Options: max_word, smart

# Vector search configuration
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance  # Options: l2_distance, cosine_distance
```

## Usage

### 1. Set Clickzetta as the Vector Store

In your Dify configuration, set:

```bash
VECTOR_STORE=clickzetta
```

### 2. Table Structure

Clickzetta will automatically create tables with the following structure:

```sql
CREATE TABLE <collection_name> (
    id STRING NOT NULL,
    content STRING NOT NULL,
    metadata JSON,
    vector VECTOR(FLOAT, <dimension>) NOT NULL,
    PRIMARY KEY (id)
);

-- Vector index for similarity search
CREATE VECTOR INDEX idx_<collection_name>_vec
ON TABLE <schema>.<collection_name>(vector)
PROPERTIES (
    "distance.function" = "cosine_distance",
    "scalar.type" = "f32"
);

-- Inverted index for full-text search (if enabled)
CREATE INVERTED INDEX idx_<collection_name>_text
ON <schema>.<collection_name>(content)
PROPERTIES (
    "analyzer" = "chinese",
    "mode" = "smart"
);
```

## Full-Text Search Capabilities

Clickzetta supports advanced full-text search with multiple analyzers:

### Analyzer Types

1. **keyword**: No tokenization, treats the entire string as a single token
   - Best for: Exact matching, IDs, codes

2. **english**: Designed for English text
   - Features: Recognizes ASCII letters and numbers, converts to lowercase
   - Best for: English content

3. **chinese**: Chinese text tokenizer
   - Features: Recognizes Chinese and English characters, removes punctuation
   - Best for: Chinese or mixed Chinese-English content

4. **unicode**: Multi-language tokenizer based on Unicode
   - Features: Recognizes text boundaries in multiple languages
   - Best for: Multi-language content

### Analyzer Modes

- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced)

### Full-Text Search Functions

- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
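For example, a minimal pre-filter query using `MATCH_ANY` against a collection table might look like the sketch below. The table and search terms are hypothetical placeholders; the columns follow the structure shown in "Table Structure" above.

```sql
-- Hypothetical sketch: return rows whose content matches at least one search term.
-- <schema> and <collection_name> stand in for an actual generated collection table.
SELECT id, content
FROM <schema>.<collection_name>
WHERE MATCH_ANY(content, 'vector database')
LIMIT 10;
```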
## Performance Optimization

### Vector Search

1. **Adjust exploration factor** for accuracy vs speed trade-off:
   ```sql
   SET cz.vector.index.search.ef=64;
   ```

2. **Use appropriate distance functions**:
   - `cosine_distance`: Best for normalized embeddings (e.g., from language models)
   - `l2_distance`: Best for raw feature vectors
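As an illustration, a similarity lookup can be sketched as a distance-filtered query. This is a hypothetical sketch mirroring the `EXPLAIN ... WHERE l2_distance(...) < threshold` pattern shown in Troubleshooting below; `<query_vector>` is a placeholder for an embedding literal supplied by the client, and the threshold is workload-dependent.

```sql
-- Hypothetical sketch: tune the exploration factor, then filter by distance
-- rather than ordering on the vector column directly (see Limitations).
SET cz.vector.index.search.ef=64;
SELECT id, content
FROM <schema>.<collection_name>
WHERE cosine_distance(vector, <query_vector>) < 0.4
LIMIT 4;
```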
### Full-Text Search

1. **Choose the right analyzer**:
   - Use `keyword` for exact matching
   - Use language-specific analyzers for better tokenization

2. **Combine with vector search**:
   - Pre-filter with full-text search for better performance
   - Use hybrid search for improved relevance

## Troubleshooting

### Connection Issues

1. Verify all 7 required configuration parameters are set
2. Check network connectivity to Clickzetta service
3. Ensure the user has proper permissions on the schema

### Search Performance

1. Verify vector index exists:
   ```sql
   SHOW INDEX FROM <schema>.<table_name>;
   ```

2. Check if vector index is being used:
   ```sql
   EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
   ```
   Look for `vector_index_search_type` in the execution plan.

### Full-Text Search Not Working

1. Verify inverted index is created
2. Check analyzer configuration matches your content language
3. Use `TOKENIZE()` function to test tokenization:
   ```sql
   SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
   ```

## Limitations

1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
2. Full-text search relevance scores are not provided by Clickzetta
3. Inverted index creation may fail for very large existing tables (the integration continues without raising an error)
4. Index naming constraints:
   - Index names must be unique within a schema
   - Only one vector index can be created per column
   - The implementation uses timestamps to ensure unique index names
5. A column can only have one vector index at a time

## References

- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
@@ -0,0 +1 @@
# Clickzetta Vector Database Integration for Dify

File diff suppressed because it is too large
@@ -246,6 +246,10 @@ class TencentVector(BaseVector):
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        document_ids_filter = kwargs.get("document_ids_filter")
        filter = None
        if document_ids_filter:
            filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
        if not self._enable_hybrid_search:
            return []
        res = self._client.hybrid_search(

@@ -269,6 +273,7 @@ class TencentVector(BaseVector):
            ),
            retrieve_vector=False,
            limit=kwargs.get("top_k", 4),
            filter=filter,
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)
@@ -172,6 +172,10 @@ class Vector:
                from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory

                return MatrixoneVectorFactory
            case VectorType.CLICKZETTA:
                from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory

                return ClickzettaVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -30,3 +30,4 @@ class VectorType(StrEnum):
    TABLESTORE = "tablestore"
    HUAWEI_CLOUD = "huawei_cloud"
    MATRIXONE = "matrixone"
    CLICKZETTA = "clickzetta"
@@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor):

    def extract(self) -> list[Document]:
        """Load given path as single page."""
        content = self.parse_docx(self.file_path, "storage")
        content = self.parse_docx(self.file_path)
        return [
            Document(
                page_content=content,

@@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor):
                paragraph_content.append(run.text)
        return "".join(paragraph_content).strip()

    def _parse_paragraph(self, paragraph, image_map):
        paragraph_content = []
        for run in paragraph.runs:
            if run.element.xpath(".//a:blip"):
                for blip in run.element.xpath(".//a:blip"):
                    embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
                    if embed_id:
                        rel_target = run.part.rels[embed_id].target_ref
                        if rel_target in image_map:
                            paragraph_content.append(image_map[rel_target])
            if run.text.strip():
                paragraph_content.append(run.text.strip())
        return " ".join(paragraph_content) if paragraph_content else ""

    def parse_docx(self, docx_path, image_folder):
    def parse_docx(self, docx_path):
        doc = DocxDocument(docx_path)
        os.makedirs(image_folder, exist_ok=True)

        content = []
@@ -5,14 +5,13 @@ from __future__ import annotations
from typing import Any, Optional

from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
    TS,
    Collection,
    Literal,
    RecursiveCharacterTextSplitter,
    Set,
    TokenTextSplitter,
    Union,
)

@@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):

            return [len(text) for text in texts]

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_character_encoder, **kwargs)
@@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool):
    @staticmethod
    def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
        try:
            if local_tz is None:
                local_tz = datetime.now().astimezone().tzinfo
            if isinstance(local_tz, str):
                local_tz = pytz.timezone(local_tz)
            local_time = datetime.strptime(localtime, time_format)
            localtime = local_tz.localize(local_time)  # type: ignore
            if local_tz is None:
                localtime = local_time.astimezone()  # type: ignore
            elif isinstance(local_tz, str):
                local_tz = pytz.timezone(local_tz)
                localtime = local_tz.localize(local_time)  # type: ignore
            timestamp = int(localtime.timestamp())  # type: ignore
            return timestamp
        except Exception as e:
@@ -1,7 +1,8 @@
import json
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib.parse import urlencode

import httpx

@@ -20,6 +21,20 @@ API_TOOL_DEFAULT_TIMEOUT = (
)


@dataclass
class ParsedResponse:
    """Represents a parsed HTTP response with type information"""

    content: Union[str, dict]
    is_json: bool

    def to_string(self) -> str:
        """Convert response to string format for credential validation"""
        if isinstance(self.content, dict):
            return json.dumps(self.content, ensure_ascii=False)
        return str(self.content)


class ApiTool(Tool):
    """
    Api tool

@@ -58,7 +73,9 @@ class ApiTool(Tool):

        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
        # validate response
        return self.validate_and_parse_response(response)
        parsed_response = self.validate_and_parse_response(response)
        # For credential validation, always return as string
        return parsed_response.to_string()

    def tool_provider_type(self) -> ToolProviderType:
        return ToolProviderType.API

@@ -112,23 +129,36 @@ class ApiTool(Tool):

        return headers

    def validate_and_parse_response(self, response: httpx.Response) -> str:
    def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
        """
        validate the response
        validate the response and return parsed content with type information

        :return: ParsedResponse with content and is_json flag
        """
        if isinstance(response, httpx.Response):
            if response.status_code >= 400:
                raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
            if not response.content:
                return "Empty response from the tool, please check your parameters and try again."
                return ParsedResponse(
                    "Empty response from the tool, please check your parameters and try again.", False
                )

            # Check content type
            content_type = response.headers.get("content-type", "").lower()
            is_json_content_type = "application/json" in content_type

            # Try to parse as JSON
            try:
                response = response.json()
                try:
                    return json.dumps(response, ensure_ascii=False)
                except Exception:
                    return json.dumps(response)
                json_response = response.json()
                # If content-type indicates JSON, return as JSON object
                if is_json_content_type:
                    return ParsedResponse(json_response, True)
                else:
                    # If content-type doesn't indicate JSON, treat as text regardless of content
                    return ParsedResponse(response.text, False)
            except Exception:
                return response.text
                # Not valid JSON, return as text
                return ParsedResponse(response.text, False)
        else:
            raise ValueError(f"Invalid response type {type(response)}")

@@ -369,7 +399,14 @@ class ApiTool(Tool):
        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)

        # validate response
        response = self.validate_and_parse_response(response)
        parsed_response = self.validate_and_parse_response(response)

        # assemble invoke message
        yield self.create_text_message(response)
        # assemble invoke message based on response type
        if parsed_response.is_json and isinstance(parsed_response.content, dict):
            yield self.create_json_message(parsed_response.content)
        else:
            # Convert to string if needed and create text message
            text_response = (
                parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
            )
            yield self.create_text_message(text_response)
@@ -29,7 +29,7 @@ from core.tools.errors import (
    ToolProviderCredentialValidationError,
    ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatorUserRole

@@ -247,7 +247,8 @@ class ToolEngine:
                )
            elif response.type == ToolInvokeMessage.MessageType.JSON:
                result += json.dumps(
                    cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
                    safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
                    ensure_ascii=False,
                )
            else:
                result += str(response.message)
@@ -1,7 +1,14 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional
from typing import Optional, cast
from uuid import UUID

import numpy as np
import pytz
from flask_login import current_user

from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage

@@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)


def safe_json_value(v):
    if isinstance(v, datetime):
        tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
        if not tz_name:
            tz_name = "UTC"
        return v.astimezone(pytz.timezone(tz_name)).isoformat()
    elif isinstance(v, date):
        return v.isoformat()
    elif isinstance(v, UUID):
        return str(v)
    elif isinstance(v, Decimal):
        return float(v)
    elif isinstance(v, bytes):
        try:
            return v.decode("utf-8")
        except UnicodeDecodeError:
            return v.hex()
    elif isinstance(v, memoryview):
        return v.tobytes().hex()
    elif isinstance(v, np.ndarray):
        return v.tolist()
    elif isinstance(v, dict):
        return safe_json_dict(v)
    elif isinstance(v, list | tuple | set):
        return [safe_json_value(i) for i in v]
    else:
        return v


def safe_json_dict(d):
    if not isinstance(d, dict):
        raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
    return {k: safe_json_value(v) for k, v in d.items()}


class ToolFileMessageTransformer:
    @classmethod
    def transform_tool_invoke_messages(

@@ -113,6 +155,12 @@ class ToolFileMessageTransformer:
                    )
                else:
                    yield message

            elif message.type == ToolInvokeMessage.MessageType.JSON:
                if isinstance(message.message, ToolInvokeMessage.JsonMessage):
                    json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
                    json_msg.json_object = safe_json_value(json_msg.json_object)
                    yield message
                else:
                    yield message
@ -119,6 +119,13 @@ class ObjectSegment(Segment):
|
|||
|
||||
|
||||
class ArraySegment(Segment):
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# Return empty string for empty arrays instead of "[]"
|
||||
if not self.value:
|
||||
return ""
|
||||
return super().text
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
items = []
|
||||
|
|
@@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment):

    @property
    def text(self) -> str:
        # Return empty string for empty arrays instead of "[]"
        if not self.value:
            return ""
        return json.dumps(self.value, ensure_ascii=False)

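# Behavior illustrated (standalone sketch, not the actual Segment classes): an empty
# array segment now renders as "" rather than the literal "[]".
import json


def array_string_text(value: list[str]) -> str:
    # Same guard as ArrayStringSegment.text above
    if not value:
        return ""
    return json.dumps(value, ensure_ascii=False)


assert array_string_text([]) == ""
assert array_string_text(["a", "b"]) == '["a", "b"]'
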
@@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
    """Extract text from a file based on its file extension."""
    match file_extension:
-       case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
+       case (
+           ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml"
+           | ".c" | ".h" | ".cpp" | ".hpp" | ".cc" | ".cxx" | ".c++"
+           | ".py" | ".js" | ".ts" | ".jsx" | ".tsx" | ".java" | ".php" | ".rb" | ".go" | ".rs"
+           | ".swift" | ".kt" | ".scala" | ".sh" | ".bash" | ".bat" | ".ps1" | ".sql"
+           | ".r" | ".m" | ".pl" | ".lua" | ".vim" | ".asm" | ".s"
+           | ".css" | ".scss" | ".less" | ".sass"
+           | ".ini" | ".cfg" | ".conf" | ".toml" | ".env" | ".log" | ".vtt"
+       ):
            return _extract_text_from_plain_text(file_content)
        case ".json":
            return _extract_text_from_json(file_content)

@@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
            return _extract_text_from_eml(file_content)
        case ".msg":
            return _extract_text_from_msg(file_content)
-       case ".vtt":
-           return _extract_text_from_vtt(file_content)
        case ".properties":
            return _extract_text_from_properties(file_content)
        case _:

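# Extension dispatch sketch (standalone, simplified): grouping many extensions in one
# `case` pattern keeps source-code files on the cheap plain-text path while structured
# formats keep their dedicated parsers.
def classify_extension(ext: str) -> str:
    match ext:
        case ".txt" | ".md" | ".py" | ".ts" | ".sql" | ".log":
            return "plain-text"
        case ".json":
            return "json"
        case _:
            return "other"


print(classify_extension(".py"))    # plain-text
print(classify_extension(".json"))  # json
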
@@ -91,7 +91,7 @@ class Executor:
        self.auth = node_data.authorization
        self.timeout = timeout
        self.ssl_verify = node_data.ssl_verify
-       self.params = []
+       self.params = None
        self.headers = {}
        self.content = None
        self.files = None

@@ -139,7 +139,8 @@ class Executor:
                (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
            )

-       self.params = result
+       if result:
+           self.params = result

    def _init_headers(self):
        """

@@ -69,6 +69,19 @@ class Storage:
                from extensions.storage.supabase_storage import SupabaseStorage

                return SupabaseStorage
            case StorageType.CLICKZETTA_VOLUME:
                from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
                    ClickZettaVolumeConfig,
                    ClickZettaVolumeStorage,
                )

                def create_clickzetta_volume_storage():
                    # ClickZettaVolumeConfig automatically reads from environment variables
                    # and falls back to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
                    volume_config = ClickZettaVolumeConfig()
                    return ClickZettaVolumeStorage(volume_config)

                return create_clickzetta_volume_storage
            case _:
                raise ValueError(f"unsupported storage type {storage_type}")

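# Factory pattern sketch (hypothetical names, not the real enum values): the match arm
# above returns a zero-argument callable instead of an instance, so the ClickZetta
# connection is only opened when the storage is actually constructed.
from collections.abc import Callable


def pick_factory(storage_type: str) -> Callable[[], object]:
    def create_clickzetta() -> object:
        # In the real code this builds ClickZettaVolumeConfig() and
        # ClickZettaVolumeStorage(config); object() is a stand-in here.
        return object()

    match storage_type:
        case "clickzetta-volume":  # placeholder string for the enum value
            return create_clickzetta
        case _:
            raise ValueError(f"unsupported storage type {storage_type}")


storage = pick_factory("clickzetta-volume")()
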
@@ -0,0 +1,5 @@
"""ClickZetta Volume storage implementation."""

from .clickzetta_volume_storage import ClickZettaVolumeStorage

__all__ = ["ClickZettaVolumeStorage"]

@@ -0,0 +1,530 @@
"""ClickZetta Volume storage implementation.

This module provides a storage backend built on ClickZetta Volume functionality.
It supports Table Volume, User Volume, and External Volume types.
"""

import logging
import os
import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Optional

import clickzetta  # type: ignore[import]
from pydantic import BaseModel, model_validator

from extensions.storage.base_storage import BaseStorage

from .volume_permissions import VolumePermissionManager, check_volume_permission

logger = logging.getLogger(__name__)

class ClickZettaVolumeConfig(BaseModel):
    """Configuration for ClickZetta Volume storage."""

    username: str = ""
    password: str = ""
    instance: str = ""
    service: str = "api.clickzetta.com"
    workspace: str = "quick_start"
    vcluster: str = "default_ap"
    schema_name: str = "dify"
    volume_type: str = "table"  # table|user|external
    volume_name: Optional[str] = None  # For external volumes
    table_prefix: str = "dataset_"  # Prefix for table volume names
    dify_prefix: str = "dify_km"  # Directory prefix for User Volume
    permission_check: bool = True  # Enable/disable permission checking

@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
import os
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# Permission checking is temporarily disabled; force it to False for now
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
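# Configuration sketch (values are placeholders): the validator above lets the Volume
# storage reuse an existing CLICKZETTA_* vector-DB configuration when no dedicated
# CLICKZETTA_VOLUME_* variables are set.
import os

os.environ.setdefault("CLICKZETTA_USERNAME", "demo_user")      # placeholder
os.environ.setdefault("CLICKZETTA_PASSWORD", "demo_password")  # placeholder
os.environ.setdefault("CLICKZETTA_INSTANCE", "demo_instance")  # placeholder
os.environ.setdefault("CLICKZETTA_VOLUME_TYPE", "user")        # table|user|external

# config = ClickZettaVolumeConfig()  # would pick up the fallbacks set above
# print(config.volume_type, config.schema_name)
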
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get the appropriate volume path based on volume type."""
|
||||
if self._config.volume_type == "user":
|
||||
# Add dify prefix for User Volume to organize files
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
elif self._config.volume_type == "table":
|
||||
# Check if this should use User Volume (special directories)
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
# Use User Volume with dify prefix for special directories
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
|
||||
if dataset_id:
|
||||
return f"{self._config.table_prefix}{dataset_id}/{filename}"
|
||||
else:
|
||||
# Extract dataset_id from filename if not provided
|
||||
# Format: dataset_id/filename
|
||||
if "/" in filename:
|
||||
return filename
|
||||
else:
|
||||
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
|
||||
elif self._config.volume_type == "external":
|
||||
return filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
def save(self, filename: str, data: bytes) -> None:
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
data: File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Ensure table volume exists (for table volumes)
|
||||
if dataset_id:
|
||||
self._ensure_table_volume_exists(dataset_id)
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "save", dataset_id)
|
||||
|
||||
# Write data to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(data)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Upload to volume
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
"""Load file content from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "load_once", dataset_id)
|
||||
|
||||
# Download to temporary directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
|
||||
else:
|
||||
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
# Find the downloaded file (may be in subdirectories)
|
||||
downloaded_file = None
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
for file in files:
|
||||
if file == filename or file == os.path.basename(filename):
|
||||
downloaded_file = Path(root) / file
|
||||
break
|
||||
if downloaded_file:
|
||||
break
|
||||
|
||||
if not downloaded_file or not downloaded_file.exists():
|
||||
raise FileNotFoundError(f"Downloaded file not found: {filename}")
|
||||
|
||||
content = downloaded_file.read_bytes()
|
||||
|
||||
logger.debug("File %s loaded from ClickZetta Volume", filename)
|
||||
return content
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
"""Load file as stream from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Yields:
|
||||
File content chunks
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
batch_size = 4096
|
||||
stream = BytesIO(content)
|
||||
|
||||
while chunk := stream.read(batch_size):
|
||||
yield chunk
|
||||
|
||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
"""Download file from ClickZetta Volume to local path.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
target_filepath: Local target file path
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
with Path(target_filepath).open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
def exists(self, filename: str) -> bool:
|
||||
"""Check if file exists in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
True if file exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
exists = len(rows) > 0
|
||||
logger.debug("File %s exists check: %s", filename, exists)
|
||||
return exists
|
||||
except Exception as e:
|
||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||
return False
|
||||
|
||||
def delete(self, filename: str):
|
||||
"""Delete file from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
"""
|
||||
if not self.exists(filename):
|
||||
logger.debug("File %s not found, skip delete", filename)
|
||||
return
|
||||
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
logger.debug("File %s deleted from ClickZetta Volume", filename)
|
||||
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""Scan files and directories in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
path: Path to scan (dataset_id for table volumes)
|
||||
files: Include files in results
|
||||
directories: Include directories in results
|
||||
|
||||
Returns:
|
||||
List of file/directory paths
|
||||
"""
|
||||
try:
|
||||
# For table volumes, path is treated as dataset_id
|
||||
dataset_id = None
|
||||
if self._config.volume_type == "table":
|
||||
dataset_id = path
|
||||
path = "" # Root of the table volume
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# For User Volume, add dify prefix to path
|
||||
if volume_prefix == "USER VOLUME":
|
||||
if path:
|
||||
scan_path = f"{self._config.dify_prefix}/{path}"
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
|
||||
else:
|
||||
if path:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix}"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
result = []
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
|
||||
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
|
||||
result.append(file_path)
|
||||
|
||||
logger.debug("Scanned %d items in path %s", len(result), path)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error scanning path %s", path)
|
||||
return []
|
||||
|
|
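# SQL sketch: for a User Volume, save() and load_once() above reduce to ClickZetta
# PUT/GET statements of this shape (the temp paths shown are illustrative).
dify_prefix = "dify_km"
filename = "upload_files/tenant-1/report.txt"
volume_path = f"{dify_prefix}/{filename}"
put_sql = f"PUT '/tmp/tmp12345' TO USER VOLUME FILE '{volume_path}'"
get_sql = f"GET USER VOLUME FILE '{volume_path}' TO '/tmp/download_dir'"
print(put_sql)
print(get_sql)
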
@@ -0,0 +1,516 @@
"""ClickZetta Volume file lifecycle management.

This module provides lifecycle management features such as file versioning,
automatic cleanup, backup, and restore, covering the full lifecycle of
knowledge-base files.
"""

import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional

logger = logging.getLogger(__name__)

class FileStatus(Enum):
    """File status enum."""

    ACTIVE = "active"  # active file
    ARCHIVED = "archived"  # archived
    DELETED = "deleted"  # soft-deleted
    BACKUP = "backup"  # backup copy

@dataclass
class FileMetadata:
    """File metadata."""

    filename: str
    size: int | None
    created_at: datetime
    modified_at: datetime
    version: int | None
    status: FileStatus
    checksum: Optional[str] = None
    tags: Optional[dict[str, str]] = None
    parent_version: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to a dictionary."""
        data = asdict(self)
        data["created_at"] = self.created_at.isoformat()
        data["modified_at"] = self.modified_at.isoformat()
        data["status"] = self.status.value
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "FileMetadata":
        """Create an instance from a dictionary."""
        data = data.copy()
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        data["modified_at"] = datetime.fromisoformat(data["modified_at"])
        data["status"] = FileStatus(data["status"])
        return cls(**data)

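# Round-trip sketch for the dataclass above: to_dict() flattens datetimes and the status
# enum into JSON-friendly values, and from_dict() restores them exactly.
# (Assumes FileMetadata and FileStatus from this module are in scope.)
_meta = FileMetadata(
    filename="report.txt",
    size=42,
    created_at=datetime.now(),
    modified_at=datetime.now(),
    version=1,
    status=FileStatus.ACTIVE,
)
assert FileMetadata.from_dict(_meta.to_dict()) == _meta
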
class FileLifecycleManager:
    """File lifecycle manager."""

    def __init__(self, storage, dataset_id: Optional[str] = None):
        """Initialize the lifecycle manager.

        Args:
            storage: ClickZetta Volume storage instance
            dataset_id: Dataset ID (used for Table Volume)
        """
        self._storage = storage
        self._dataset_id = dataset_id
        self._metadata_file = ".dify_file_metadata.json"
        self._version_prefix = ".versions/"
        self._backup_prefix = ".backups/"
        self._deleted_prefix = ".deleted/"

        # Reuse the storage's permission manager if it has one
        self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)

    def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
        """Save a file and manage its lifecycle.

        Args:
            filename: File name
            data: File content
            tags: File tags

        Returns:
            File metadata
        """
        # Permission check
        if not self._check_permission(filename, "save"):
            from .volume_permissions import VolumePermissionError

            raise VolumePermissionError(
                f"Permission denied for lifecycle save operation on file: {filename}",
                operation="save",
                volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
                dataset_id=self._dataset_id,
            )

        try:
            # 1. Check whether an older version exists
            metadata_dict = self._load_metadata()
            current_metadata = metadata_dict.get(filename)

            # 2. If an older version exists, create a version backup
            if current_metadata:
                self._create_version_backup(filename, current_metadata)

            # 3. Compute file information
            now = datetime.now()
            checksum = self._calculate_checksum(data)
            new_version = (current_metadata["version"] + 1) if current_metadata else 1

            # 4. Save the new file
            self._storage.save(filename, data)

            # 5. Build the metadata
            created_at = now
            parent_version = None

            if current_metadata:
                # created_at may be an ISO string; convert it back to datetime
                if isinstance(current_metadata["created_at"], str):
                    created_at = datetime.fromisoformat(current_metadata["created_at"])
                else:
                    created_at = current_metadata["created_at"]
                parent_version = current_metadata["version"]

            file_metadata = FileMetadata(
                filename=filename,
                size=len(data),
                created_at=created_at,
                modified_at=now,
                version=new_version,
                status=FileStatus.ACTIVE,
                checksum=checksum,
                tags=tags or {},
                parent_version=parent_version,
            )

            # 6. Update the metadata index
            metadata_dict[filename] = file_metadata.to_dict()
            self._save_metadata(metadata_dict)

            logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
            return file_metadata

        except Exception as e:
            logger.exception("Failed to save file with lifecycle")
            raise

def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
|
||||
"""获取文件元数据
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
文件元数据,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
return FileMetadata.from_dict(metadata_dict[filename])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get file metadata for %s", filename)
|
||||
return None
|
||||
|
||||
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
||||
"""列出文件的所有版本
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
文件版本列表,按版本号排序
|
||||
"""
|
||||
try:
|
||||
versions = []
|
||||
|
||||
# 获取当前版本
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
versions.append(current_metadata)
|
||||
|
||||
# 获取历史版本
|
||||
version_pattern = f"{self._version_prefix}{filename}.v*"
|
||||
try:
|
||||
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
for file_path in version_files:
|
||||
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
||||
# 解析版本号
|
||||
version_str = file_path.split(".v")[-1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_str)
|
||||
# 这里简化处理,实际应该从版本文件中读取元数据
|
||||
# 暂时创建基本的元数据信息
|
||||
except ValueError:
|
||||
continue
|
||||
except:
|
||||
# 如果无法扫描版本文件,只返回当前版本
|
||||
pass
|
||||
|
||||
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list file versions for %s", filename)
|
||||
return []
|
||||
|
||||
def restore_version(self, filename: str, version: int) -> bool:
|
||||
"""恢复文件到指定版本
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
version: 要恢复的版本号
|
||||
|
||||
Returns:
|
||||
恢复是否成功
|
||||
"""
|
||||
try:
|
||||
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
||||
|
||||
# 检查版本文件是否存在
|
||||
if not self._storage.exists(version_filename):
|
||||
logger.warning("Version %s of %s not found", version, filename)
|
||||
return False
|
||||
|
||||
# 读取版本文件内容
|
||||
version_data = self._storage.load_once(version_filename)
|
||||
|
||||
# 保存当前版本为备份
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata.to_dict())
|
||||
|
||||
# 恢复文件
|
||||
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to restore %s to version %s", filename, version)
|
||||
return False
|
||||
|
||||
def archive_file(self, filename: str) -> bool:
|
||||
"""归档文件
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
归档是否成功
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "archive"):
|
||||
logger.warning("Permission denied for archive operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新文件状态为归档
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename not in metadata_dict:
|
||||
logger.warning("File %s not found in metadata", filename)
|
||||
return False
|
||||
|
||||
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s archived successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to archive file %s", filename)
|
||||
return False
|
||||
|
||||
def soft_delete_file(self, filename: str) -> bool:
|
||||
"""软删除文件(移动到删除目录)
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "delete"):
|
||||
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查文件是否存在
|
||||
if not self._storage.exists(filename):
|
||||
logger.warning("File %s not found", filename)
|
||||
return False
|
||||
|
||||
# 读取文件内容
|
||||
file_data = self._storage.load_once(filename)
|
||||
|
||||
# 移动到删除目录
|
||||
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self._storage.save(deleted_filename, file_data)
|
||||
|
||||
# 删除原文件
|
||||
self._storage.delete(filename)
|
||||
|
||||
# 更新元数据
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
metadata_dict[filename]["status"] = FileStatus.DELETED.value
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s soft deleted successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to soft delete file %s", filename)
|
||||
return False
|
||||
|
||||
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
||||
"""清理旧版本文件
|
||||
|
||||
Args:
|
||||
max_versions: 保留的最大版本数
|
||||
max_age_days: 版本文件的最大保留天数
|
||||
|
||||
Returns:
|
||||
清理的文件数量
|
||||
"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
cutoff_date = datetime.now() - timedelta(days=max_age_days)
|
||||
|
||||
# 获取所有版本文件
|
||||
try:
|
||||
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
||||
|
||||
# 按文件分组
|
||||
file_versions: dict[str, list[tuple[int, str]]] = {}
|
||||
for version_file in version_files:
|
||||
# 解析文件名和版本
|
||||
parts = version_file[len(self._version_prefix) :].split(".v")
|
||||
if len(parts) >= 2:
|
||||
base_filename = parts[0]
|
||||
version_part = parts[1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_part)
|
||||
if base_filename not in file_versions:
|
||||
file_versions[base_filename] = []
|
||||
file_versions[base_filename].append((version_num, version_file))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# 清理每个文件的旧版本
|
||||
for base_filename, versions in file_versions.items():
|
||||
# 按版本号排序
|
||||
versions.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 保留最新的max_versions个版本,删除其余的
|
||||
if len(versions) > max_versions:
|
||||
to_delete = versions[max_versions:]
|
||||
for version_num, version_file in to_delete:
|
||||
self._storage.delete(version_file)
|
||||
cleaned_count += 1
|
||||
logger.debug("Cleaned old version: %s", version_file)
|
||||
|
||||
logger.info("Cleaned %d old version files", cleaned_count)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not scan for version files: %s", e)
|
||||
|
||||
return cleaned_count
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to cleanup old versions")
|
||||
return 0
|
||||
|
||||
def get_storage_statistics(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息
|
||||
|
||||
Returns:
|
||||
存储统计字典
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
|
||||
stats: dict[str, Any] = {
|
||||
"total_files": len(metadata_dict),
|
||||
"active_files": 0,
|
||||
"archived_files": 0,
|
||||
"deleted_files": 0,
|
||||
"total_size": 0,
|
||||
"versions_count": 0,
|
||||
"oldest_file": None,
|
||||
"newest_file": None,
|
||||
}
|
||||
|
||||
oldest_date = None
|
||||
newest_date = None
|
||||
|
||||
for filename, metadata in metadata_dict.items():
|
||||
file_meta = FileMetadata.from_dict(metadata)
|
||||
|
||||
# 统计文件状态
|
||||
if file_meta.status == FileStatus.ACTIVE:
|
||||
stats["active_files"] = (stats["active_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.ARCHIVED:
|
||||
stats["archived_files"] = (stats["archived_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.DELETED:
|
||||
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
||||
|
||||
# 统计大小
|
||||
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
||||
|
||||
# 统计版本
|
||||
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
||||
|
||||
# 找出最新和最旧的文件
|
||||
if oldest_date is None or file_meta.created_at < oldest_date:
|
||||
oldest_date = file_meta.created_at
|
||||
stats["oldest_file"] = filename
|
||||
|
||||
if newest_date is None or file_meta.modified_at > newest_date:
|
||||
newest_date = file_meta.modified_at
|
||||
stats["newest_file"] = filename
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get storage statistics")
|
||||
return {}
|
||||
|
||||
def _create_version_backup(self, filename: str, metadata: dict):
|
||||
"""创建版本备份"""
|
||||
try:
|
||||
# 读取当前文件内容
|
||||
current_data = self._storage.load_once(filename)
|
||||
|
||||
# 保存为版本文件
|
||||
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
||||
self._storage.save(version_filename, current_data)
|
||||
|
||||
logger.debug("Created version backup: %s", version_filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
||||
|
||||
def _load_metadata(self) -> dict[str, Any]:
|
||||
"""加载元数据文件"""
|
||||
try:
|
||||
if self._storage.exists(self._metadata_file):
|
||||
metadata_content = self._storage.load_once(self._metadata_file)
|
||||
result = json.loads(metadata_content.decode("utf-8"))
|
||||
return dict(result) if result else {}
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load metadata: %s", e)
|
||||
return {}
|
||||
|
||||
def _save_metadata(self, metadata_dict: dict):
|
||||
"""保存元数据文件"""
|
||||
try:
|
||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
||||
logger.debug("Metadata saved successfully")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save metadata")
|
||||
raise
|
||||
|
||||
def _calculate_checksum(self, data: bytes) -> str:
|
||||
"""计算文件校验和"""
|
||||
import hashlib
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
def _check_permission(self, filename: str, operation: str) -> bool:
|
||||
"""检查文件操作权限
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
operation: 操作类型
|
||||
|
||||
Returns:
|
||||
True if permission granted, False otherwise
|
||||
"""
|
||||
# 如果没有权限管理器,默认允许
|
||||
if not self._permission_manager:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 根据操作类型映射到权限
|
||||
operation_mapping = {
|
||||
"save": "save",
|
||||
"load": "load_once",
|
||||
"delete": "delete",
|
||||
"archive": "delete", # 归档需要删除权限
|
||||
"restore": "save", # 恢复需要写权限
|
||||
"cleanup": "delete", # 清理需要删除权限
|
||||
"read": "load_once",
|
||||
"write": "save",
|
||||
}
|
||||
|
||||
mapped_operation = operation_mapping.get(operation, operation)
|
||||
|
||||
# 检查权限
|
||||
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
return False
|
||||
|
|
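# Lifecycle usage sketch with an in-memory stand-in for the Volume storage (the real
# manager is constructed with a ClickZettaVolumeStorage instance instead; assumes the
# FileLifecycleManager defined above is in scope).
class _MemoryStorage:
    def __init__(self):
        self._files: dict[str, bytes] = {}

    def save(self, name, data):
        self._files[name] = data

    def load_once(self, name):
        return self._files[name]

    def exists(self, name):
        return name in self._files

    def delete(self, name):
        self._files.pop(name, None)

    def scan(self, path, files=True, directories=False):
        return list(self._files)


manager = FileLifecycleManager(_MemoryStorage(), dataset_id="demo")
manager.save_with_lifecycle("notes.txt", b"v1", tags={"source": "sketch"})
manager.save_with_lifecycle("notes.txt", b"v2")        # bumps the version, keeps a backup
print(manager.get_file_metadata("notes.txt").version)  # 2
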
@@ -0,0 +1,646 @@
"""ClickZetta Volume permission management.

This module provides permission checking, validation, and management for Volumes.
Under ClickZetta's permission model, different Volume types have different
permission requirements.
"""

import logging
from enum import Enum
from typing import Optional

logger = logging.getLogger(__name__)

class VolumePermission(Enum):
    """Volume permission types."""

    READ = "SELECT"  # maps to ClickZetta's SELECT privilege
    WRITE = "INSERT,UPDATE,DELETE"  # maps to ClickZetta's write privileges
    LIST = "SELECT"  # listing files requires SELECT
    DELETE = "INSERT,UPDATE,DELETE"  # deleting files requires write privileges
    USAGE = "USAGE"  # basic privilege required by External Volumes

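# Quick sketch of how these privilege strings are used further below: an operation is
# allowed when every privilege it needs is contained in the set granted to the user
# (the granted set shown here is illustrative, e.g. parsed from SHOW GRANTS).
required = set(VolumePermission.WRITE.value.split(","))  # {"INSERT", "UPDATE", "DELETE"}
granted = {"SELECT", "INSERT", "UPDATE", "DELETE"}
print(required.issubset(granted))                        # True -> write operations allowed
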
class VolumePermissionManager:
    """Volume permission manager."""

    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
        """Initialize the permission manager.

        Args:
            connection_or_config: ClickZetta connection object or configuration dict
            volume_type: Volume type (user|table|external)
            volume_name: Volume name (for external volumes)
        """
        # Two initialization styles are supported: a connection object or a config dict
        if isinstance(connection_or_config, dict):
            # Create a connection from the config dict
import clickzetta # type: ignore[import-untyped]
|
||||
|
||||
config = connection_or_config
|
||||
self._connection = clickzetta.connect(
|
||||
username=config.get("username"),
|
||||
password=config.get("password"),
|
||||
instance=config.get("instance"),
|
||||
service=config.get("service"),
|
||||
workspace=config.get("workspace"),
|
||||
vcluster=config.get("vcluster"),
|
||||
schema=config.get("schema") or config.get("database"),
|
||||
)
|
||||
self._volume_type = config.get("volume_type", volume_type)
|
||||
self._volume_name = config.get("volume_name", volume_name)
|
||||
else:
|
||||
# 直接使用连接对象
|
||||
self._connection = connection_or_config
|
||||
self._volume_type = volume_type
|
||||
self._volume_name = volume_name
|
||||
|
||||
if not self._connection:
|
||||
raise ValueError("Valid connection or config is required")
|
||||
if not self._volume_type:
|
||||
raise ValueError("volume_type is required")
|
||||
|
||||
self._permission_cache: dict[str, set[str]] = {}
|
||||
self._current_username = None # 将从连接中获取当前用户名
|
||||
|
||||
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
|
||||
"""检查用户是否有执行特定操作的权限
|
||||
|
||||
Args:
|
||||
operation: 要执行的操作类型
|
||||
dataset_id: 数据集ID (用于table volume)
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self._volume_type == "user":
|
||||
return self._check_user_volume_permission(operation)
|
||||
elif self._volume_type == "table":
|
||||
return self._check_table_volume_permission(operation, dataset_id)
|
||||
elif self._volume_type == "external":
|
||||
return self._check_external_volume_permission(operation)
|
||||
else:
|
||||
logger.warning("Unknown volume type: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission check failed")
|
||||
return False
|
||||
|
||||
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""检查User Volume权限
|
||||
|
||||
User Volume权限规则:
|
||||
- 用户对自己的User Volume有全部权限
|
||||
- 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限
|
||||
- 更注重连接身份验证,而不是复杂的权限检查
|
||||
"""
|
||||
try:
|
||||
# 获取当前用户名
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# 检查基本连接状态
|
||||
with self._connection.cursor() as cursor:
|
||||
# 简单的连接测试,如果能执行查询说明用户有基本权限
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
logger.debug(
|
||||
"User Volume permission check for %s, operation %s: granted (basic connection verified)",
|
||||
current_user,
|
||||
operation.name,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"User Volume permission check failed: cannot verify basic connection for %s", current_user
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("User Volume permission check failed")
|
||||
# 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示
|
||||
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
|
||||
"""检查Table Volume权限
|
||||
|
||||
Table Volume权限规则:
|
||||
- Table Volume权限继承对应表的权限
|
||||
- SELECT权限 -> 可以READ/LIST文件
|
||||
- INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件
|
||||
"""
|
||||
if not dataset_id:
|
||||
logger.warning("dataset_id is required for table volume permission check")
|
||||
return False
|
||||
|
||||
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
|
||||
|
||||
try:
|
||||
# 检查表权限
|
||||
permissions = self._get_table_permissions(table_name)
|
||||
required_permissions = set(operation.value.split(","))
|
||||
|
||||
# 检查是否有所需的所有权限
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
table_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Table volume permission check failed for %s", table_name)
|
||||
return False
|
||||
|
||||
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""检查External Volume权限
|
||||
|
||||
External Volume权限规则:
|
||||
- 尝试获取对External Volume的权限
|
||||
- 如果权限检查失败,进行备选验证
|
||||
- 对于开发环境,提供更宽松的权限检查
|
||||
"""
|
||||
if not self._volume_name:
|
||||
logger.warning("volume_name is required for external volume permission check")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查External Volume权限
|
||||
permissions = self._get_external_volume_permissions(self._volume_name)
|
||||
|
||||
# External Volume权限映射:根据操作类型确定所需权限
|
||||
required_permissions = set()
|
||||
|
||||
if operation in [VolumePermission.READ, VolumePermission.LIST]:
|
||||
required_permissions.add("read")
|
||||
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
|
||||
required_permissions.add("write")
|
||||
|
||||
# 检查是否有所需的所有权限
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
self._volume_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
# 如果权限检查失败,尝试备选验证
|
||||
if not has_permission:
|
||||
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
|
||||
|
||||
# 备选验证:尝试列出Volume来验证基本访问权限
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == self._volume_name:
|
||||
logger.info("Fallback verification successful for %s", self._volume_name)
|
||||
return True
|
||||
except Exception as fallback_e:
|
||||
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("External volume permission check failed for %s", self._volume_name)
|
||||
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _get_table_permissions(self, table_name: str) -> set[str]:
|
||||
"""获取用户对指定表的权限
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
用户对该表的权限集合
|
||||
"""
|
||||
cache_key = f"table:{table_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查当前用户权限
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# 解析权限结果,查找对该表的权限
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
object_name = grant[2] if len(grant) > 2 else ""
|
||||
|
||||
# 检查是否是对该表的权限
|
||||
if (
|
||||
object_type == "TABLE"
|
||||
and object_name == table_name
|
||||
or object_type == "SCHEMA"
|
||||
and object_name in table_name
|
||||
):
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||
permissions.add("SELECT")
|
||||
except Exception:
|
||||
logger.debug("Cannot query table %s, no SELECT permission", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check table permissions for %s: %s", table_name, e)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
pass
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_current_username(self) -> str:
|
||||
"""获取当前用户名"""
|
||||
if self._current_username:
|
||||
return self._current_username
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
self._current_username = result[0]
|
||||
return str(self._current_username)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get current username")
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _get_user_permissions(self, username: str) -> set[str]:
|
||||
"""获取用户的基本权限集合"""
|
||||
cache_key = f"user_permissions:{username}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查当前用户权限
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# 解析权限结果,查找用户的基本权限
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
|
||||
# 收集所有相关权限
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check user permissions for %s: %s", username, e)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
pass
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
|
||||
"""获取用户对指定External Volume的权限
|
||||
|
||||
Args:
|
||||
volume_name: External Volume名称
|
||||
|
||||
Returns:
|
||||
用户对该Volume的权限集合
|
||||
"""
|
||||
cache_key = f"external_volume:{volume_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查Volume权限
|
||||
logger.info("Checking permissions for volume: %s", volume_name)
|
||||
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
logger.info("Raw grants result for %s: %s", volume_name, grants)
|
||||
|
||||
# 解析权限结果
|
||||
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
|
||||
# grantee_name, grantor_name, grant_option, granted_time)
|
||||
for grant in grants:
|
||||
logger.info("Processing grant: %s", grant)
|
||||
if len(grant) >= 5:
|
||||
granted_type = grant[0]
|
||||
privilege = grant[1].upper()
|
||||
granted_on = grant[3]
|
||||
object_name = grant[4]
|
||||
|
||||
logger.info(
|
||||
"Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
|
||||
granted_type,
|
||||
privilege,
|
||||
granted_on,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# 检查是否是对该Volume的权限或者是层级权限
|
||||
if (
|
||||
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
|
||||
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
|
||||
logger.info("Matching grant found for %s", volume_name)
|
||||
|
||||
if "READ" in privilege:
|
||||
permissions.add("read")
|
||||
logger.info("Added READ permission for %s", volume_name)
|
||||
if "WRITE" in privilege:
|
||||
permissions.add("write")
|
||||
logger.info("Added WRITE permission for %s", volume_name)
|
||||
if "ALTER" in privilege:
|
||||
permissions.add("alter")
|
||||
logger.info("Added ALTER permission for %s", volume_name)
|
||||
if privilege == "ALL":
|
||||
permissions.update(["read", "write", "alter"])
|
||||
logger.info("Added ALL permissions for %s", volume_name)
|
||||
|
||||
logger.info("Final permissions for %s: %s", volume_name, permissions)
|
||||
|
||||
# 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
permissions.add("read") # 至少有读权限
|
||||
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Cannot access volume %s, no basic permission", volume_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
|
||||
# 在权限检查失败时,尝试基本的Volume访问验证
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
logger.info("Basic volume access verified for %s", volume_name)
|
||||
permissions.add("read")
|
||||
permissions.add("write") # 假设有写权限
|
||||
break
|
||||
except Exception as basic_e:
|
||||
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
|
||||
# 最后的备选方案:假设有基本权限
|
||||
permissions.add("read")
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def clear_permission_cache(self):
|
||||
"""清空权限缓存"""
|
||||
self._permission_cache.clear()
|
||||
logger.debug("Permission cache cleared")
|
||||
|
||||
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
|
||||
"""获取权限摘要
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID (用于table volume)
|
||||
|
||||
Returns:
|
||||
权限摘要字典
|
||||
"""
|
||||
summary = {}
|
||||
|
||||
for operation in VolumePermission:
|
||||
summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
|
||||
|
||||
return summary
|
||||
|
||||
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
|
||||
"""检查文件路径的权限继承
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
operation: 要执行的操作
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
# 解析文件路径
|
||||
path_parts = file_path.strip("/").split("/")
|
||||
|
||||
if not path_parts:
|
||||
logger.warning("Invalid file path for permission inheritance check")
|
||||
return False
|
||||
|
||||
# 对于Table Volume,第一层是dataset_id
|
||||
if self._volume_type == "table":
|
||||
if len(path_parts) < 1:
|
||||
return False
|
||||
|
||||
dataset_id = path_parts[0]
|
||||
|
||||
# 检查对dataset的权限
|
||||
has_dataset_permission = self.check_permission(operation, dataset_id)
|
||||
|
||||
if not has_dataset_permission:
|
||||
logger.debug("Permission denied for dataset %s", dataset_id)
|
||||
return False
|
||||
|
||||
# 检查路径遍历攻击
|
||||
if self._contains_path_traversal(file_path):
|
||||
logger.warning("Path traversal attack detected: %s", file_path)
|
||||
return False
|
||||
|
||||
# 检查是否访问敏感目录
|
||||
if self._is_sensitive_path(file_path):
|
||||
logger.warning("Access to sensitive path denied: %s", file_path)
|
||||
return False
|
||||
|
||||
logger.debug("Permission inherited for path %s", file_path)
|
||||
return True
|
||||
|
||||
elif self._volume_type == "user":
|
||||
# User Volume的权限继承
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# 检查是否试图访问其他用户的目录
|
||||
if len(path_parts) > 1 and path_parts[0] != current_user:
|
||||
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
|
||||
return False
|
||||
|
||||
# 检查基本权限
|
||||
return self.check_permission(operation)
|
||||
|
||||
elif self._volume_type == "external":
|
||||
# External Volume的权限继承
|
||||
# 检查对External Volume的权限
|
||||
return self.check_permission(operation)
|
||||
|
||||
else:
|
||||
logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission inheritance check failed")
|
||||
return False
|
||||
|
||||
def _contains_path_traversal(self, file_path: str) -> bool:
|
||||
"""检查路径是否包含路径遍历攻击"""
|
||||
# 检查常见的路径遍历模式
|
||||
traversal_patterns = [
|
||||
"../",
|
||||
"..\\",
|
||||
"..%2f",
|
||||
"..%2F",
|
||||
"..%5c",
|
||||
"..%5C",
|
||||
"%2e%2e%2f",
|
||||
"%2e%2e%5c",
|
||||
"....//",
|
||||
"....\\\\",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
for pattern in traversal_patterns:
|
||||
if pattern in file_path_lower:
|
||||
return True
|
||||
|
||||
# 检查绝对路径
|
||||
if file_path.startswith("/") or file_path.startswith("\\"):
|
||||
return True
|
||||
|
||||
# 检查Windows驱动器路径
|
||||
if len(file_path) >= 2 and file_path[1] == ":":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_sensitive_path(self, file_path: str) -> bool:
|
||||
"""检查路径是否为敏感路径"""
|
||||
sensitive_patterns = [
|
||||
"passwd",
|
||||
"shadow",
|
||||
"hosts",
|
||||
"config",
|
||||
"secrets",
|
||||
"private",
|
||||
"key",
|
||||
"certificate",
|
||||
"cert",
|
||||
"ssl",
|
||||
"database",
|
||||
"backup",
|
||||
"dump",
|
||||
"log",
|
||||
"tmp",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
||||
|
||||
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
|
||||
"""验证操作权限
|
||||
|
||||
Args:
|
||||
operation: 操作名称 (save|load|exists|delete|scan)
|
||||
dataset_id: 数据集ID
|
||||
|
||||
Returns:
|
||||
True if operation is allowed, False otherwise
|
||||
"""
|
||||
operation_mapping = {
|
||||
"save": VolumePermission.WRITE,
|
||||
"load": VolumePermission.READ,
|
||||
"load_once": VolumePermission.READ,
|
||||
"load_stream": VolumePermission.READ,
|
||||
"download": VolumePermission.READ,
|
||||
"exists": VolumePermission.READ,
|
||||
"delete": VolumePermission.DELETE,
|
||||
"scan": VolumePermission.LIST,
|
||||
}
|
||||
|
||||
if operation not in operation_mapping:
|
||||
logger.warning("Unknown operation: %s", operation)
|
||||
return False
|
||||
|
||||
volume_permission = operation_mapping[operation]
|
||||
return self.check_permission(volume_permission, dataset_id)
|
||||
|
||||
|
||||
class VolumePermissionError(Exception):
|
||||
"""Volume权限错误异常"""
|
||||
|
||||
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
|
||||
self.operation = operation
|
||||
self.volume_type = volume_type
|
||||
self.dataset_id = dataset_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_volume_permission(
|
||||
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""权限检查装饰器函数
|
||||
|
||||
Args:
|
||||
permission_manager: 权限管理器
|
||||
operation: 操作名称
|
||||
dataset_id: 数据集ID
|
||||
|
||||
Raises:
|
||||
VolumePermissionError: 如果没有权限
|
||||
"""
|
||||
if not permission_manager.validate_operation(operation, dataset_id):
|
||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
|
||||
if dataset_id:
|
||||
error_message += f" (dataset: {dataset_id})"
|
||||
|
||||
raise VolumePermissionError(
|
||||
error_message,
|
||||
operation=operation,
|
||||
volume_type=permission_manager._volume_type or "unknown",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
|
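For orientation, a minimal usage sketch of the helpers above. Only the names defined in this file are real; the import path and the `guarded_save` wrapper are assumptions for illustration, not part of the diff:

```python
# Sketch only: guard a storage write with the permission helper defined above.
from extensions.storage.clickzetta_volume.volume_permissions import (  # hypothetical module path
    VolumePermissionError,
    VolumePermissionManager,
    check_volume_permission,
)


def guarded_save(manager: VolumePermissionManager, storage, key: str, data: bytes, dataset_id: str | None = None):
    try:
        check_volume_permission(manager, "save", dataset_id)  # raises VolumePermissionError on denial
    except VolumePermissionError as e:
        # Surface which operation and volume were denied before re-raising.
        print(f"denied: {e.operation} on {e.volume_type} volume (dataset={e.dataset_id})")
        raise
    storage.save(key, data)
```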
@@ -5,6 +5,7 @@ class StorageType(StrEnum):
     ALIYUN_OSS = "aliyun-oss"
     AZURE_BLOB = "azure-blob"
     BAIDU_OBS = "baidu-obs"
+    CLICKZETTA_VOLUME = "clickzetta-volume"
     GOOGLE_STORAGE = "google-storage"
     HUAWEI_OBS = "huawei-obs"
     LOCAL = "local"
@@ -1,5 +1,4 @@
 import hashlib
-import os
 from typing import Union

 from Crypto.Cipher import AES
@@ -18,7 +17,7 @@ def generate_key_pair(tenant_id: str) -> str:
     pem_private = private_key.export_key()
     pem_public = public_key.export_key()

-    filepath = os.path.join("privkeys", tenant_id, "private.pem")
+    filepath = f"privkeys/{tenant_id}/private.pem"

     storage.save(filepath, pem_private)

@@ -48,7 +47,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:


 def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
-    filepath = os.path.join("privkeys", tenant_id, "private.pem")
+    filepath = f"privkeys/{tenant_id}/private.pem"

     cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
     private_key = redis_client.get(cache_key)
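A side note on the os.path.join-to-f-string change above: storage keys should use forward slashes on every platform, which os.path.join does not guarantee on Windows. A minimal, standard-library-only illustration (not part of the diff):

```python
import ntpath  # Windows-style path joining, available on any platform

tenant_id = "abc123"
# On Windows, os.path.join would produce a backslash-separated key:
assert ntpath.join("privkeys", tenant_id, "private.pem") == r"privkeys\abc123\private.pem"
# The f-string form always yields the forward-slash key the storage backends expect:
assert f"privkeys/{tenant_id}/private.pem" == "privkeys/abc123/private.pem"
```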
@@ -194,6 +194,7 @@ vdb = [
     "alibabacloud_tea_openapi~=0.3.9",
     "chromadb==0.5.20",
     "clickhouse-connect~=0.7.16",
+    "clickzetta-connector-python>=0.8.102",
     "couchbase~=4.3.0",
     "elasticsearch==8.14.0",
     "opensearch-py==2.4.0",
@@ -213,3 +214,4 @@ vdb = [
     "xinference-client~=1.2.2",
     "mo-vector~=0.1.13",
 ]

@@ -3,7 +3,7 @@ import time

 import click
 from sqlalchemy import text
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError

 import app
 from configs import dify_config
@@ -27,8 +27,8 @@ def clean_embedding_cache_task():
             .all()
         )
         embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
-    except NotFound:
-        break
+    except SQLAlchemyError:
+        raise
     if embedding_ids:
         for embedding_id in embedding_ids:
             db.session.execute(
@@ -3,7 +3,7 @@ import logging
 import time

 import click
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError

 import app
 from configs import dify_config
@@ -42,8 +42,8 @@ def clean_messages():
             .all()
         )

-    except NotFound:
-        break
+    except SQLAlchemyError:
+        raise
     if not messages:
         break
     for message in messages:
@@ -3,7 +3,7 @@ import time

 import click
 from sqlalchemy import func, select
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError

 import app
 from configs import dify_config
@@ -65,8 +65,8 @@ def clean_unused_datasets_task():

             datasets = db.paginate(stmt, page=1, per_page=50)

-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
         if datasets.items is None or len(datasets.items) == 0:
             break
         for dataset in datasets:
@@ -146,8 +146,8 @@ def clean_unused_datasets_task():
             )
             datasets = db.paginate(stmt, page=1, per_page=50)

-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
         if datasets.items is None or len(datasets.items) == 0:
             break
         for dataset in datasets:
@@ -50,12 +50,16 @@ class ConversationService:
                 Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
                 or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
             )
-        # Check if include_ids is not None and not empty to avoid WHERE false condition
-        if include_ids is not None and len(include_ids) > 0:
+        # Check if include_ids is not None to apply filter
+        if include_ids is not None:
+            if len(include_ids) == 0:
+                # If include_ids is empty, return empty result
+                return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
             stmt = stmt.where(Conversation.id.in_(include_ids))
-        # Check if exclude_ids is not None and not empty to avoid WHERE false condition
-        if exclude_ids is not None and len(exclude_ids) > 0:
-            stmt = stmt.where(~Conversation.id.in_(exclude_ids))
+        # Check if exclude_ids is not None to apply filter
+        if exclude_ids is not None:
+            if len(exclude_ids) > 0:
+                stmt = stmt.where(~Conversation.id.in_(exclude_ids))

         # define sort fields and directions
         sort_field, sort_direction = cls._get_sort_params(sort_by)
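The behavioral intent of that hunk, shown in isolation as a toy sketch (plain Python, not the service code): an explicitly empty include_ids list now short-circuits to an empty page instead of building an always-false SQL filter.

```python
# Toy sketch of the new guard; returning [] stands in for the empty InfiniteScrollPagination.
def filter_ids(all_ids: list[str], include_ids: list[str] | None) -> list[str]:
    if include_ids is not None:
        if len(include_ids) == 0:
            return []  # an explicit empty filter means "nothing", no SQL round-trip needed
        return [i for i in all_ids if i in include_ids]
    return all_ids


assert filter_ids(["a", "b"], []) == []
assert filter_ids(["a", "b"], ["b"]) == ["b"]
assert filter_ids(["a", "b"], None) == ["a", "b"]
```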
@@ -256,7 +256,7 @@ class WorkflowDraftVariableService:
     def _reset_node_var_or_sys_var(
         self, workflow: Workflow, variable: WorkflowDraftVariable
     ) -> WorkflowDraftVariable | None:
-        # If a variable does not allow updating, it makes no sence to resetting it.
+        # If a variable does not allow updating, it makes no sense to reset it.
         if not variable.editable:
             return variable
         # No execution record for this variable, delete the variable instead.
@@ -478,7 +478,7 @@ def _batch_upsert_draft_variable(
                 "node_execution_id": stmt.excluded.node_execution_id,
             },
         )
-    elif _UpsertPolicy.IGNORE:
+    elif policy == _UpsertPolicy.IGNORE:
         stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
     else:
         raise Exception("Invalid value for update policy.")
@@ -56,15 +56,24 @@ def clean_dataset_task(
         documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()

+        # Fix: Always clean vector database resources regardless of document existence
+        # This ensures all 33 vector databases properly drop tables/collections/indices
+        if doc_form is None:
+            # Use default paragraph index type for empty datasets to enable vector database cleanup
+            from core.rag.index_processor.constant.index_type import IndexType
+
+            doc_form = IndexType.PARAGRAPH_INDEX
+            logging.info(
+                click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow")
+            )
+
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
+
         if documents is None or len(documents) == 0:
             logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
         else:
             logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
-            # Specify the index type before initializing the index processor
-            if doc_form is None:
-                raise ValueError("Index type must be specified.")
-            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)

             for document in documents:
                 db.session.delete(document)
@@ -0,0 +1,168 @@
|
|||
"""
|
||||
Unit tests for App description validation functions.
|
||||
|
||||
This test module validates the 400-character limit enforcement
|
||||
for App descriptions across all creation and editing endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the API root to Python path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
|
||||
|
||||
class TestAppDescriptionValidationUnit:
|
||||
"""Unit tests for description validation function"""
|
||||
|
||||
def test_validate_description_length_function(self):
|
||||
"""Test the _validate_description_length function directly"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
|
||||
# Test valid descriptions
|
||||
assert _validate_description_length("") == ""
|
||||
assert _validate_description_length("x" * 400) == "x" * 400
|
||||
assert _validate_description_length(None) is None
|
||||
|
||||
# Test invalid descriptions
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_validate_description_length("x" * 401)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_validate_description_length("x" * 500)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_validate_description_length("x" * 1000)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
def test_validation_consistency_with_dataset(self):
|
||||
"""Test that App and Dataset validation functions are consistent"""
|
||||
from controllers.console.app.app import _validate_description_length as app_validate
|
||||
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
|
||||
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate
|
||||
|
||||
# Test same valid inputs
|
||||
valid_desc = "x" * 400
|
||||
assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc)
|
||||
assert app_validate("") == dataset_validate("") == service_dataset_validate("")
|
||||
assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None)
|
||||
|
||||
# Test same invalid inputs produce same error
|
||||
invalid_desc = "x" * 401
|
||||
|
||||
app_error = None
|
||||
dataset_error = None
|
||||
service_dataset_error = None
|
||||
|
||||
try:
|
||||
app_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
app_error = str(e)
|
||||
|
||||
try:
|
||||
dataset_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
dataset_error = str(e)
|
||||
|
||||
try:
|
||||
service_dataset_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
service_dataset_error = str(e)
|
||||
|
||||
assert app_error == dataset_error == service_dataset_error
|
||||
assert app_error == "Description cannot exceed 400 characters."
|
||||
|
||||
def test_boundary_values(self):
|
||||
"""Test boundary values for description validation"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
|
||||
# Test exact boundary
|
||||
exactly_400 = "x" * 400
|
||||
assert _validate_description_length(exactly_400) == exactly_400
|
||||
|
||||
# Test just over boundary
|
||||
just_over_400 = "x" * 401
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(just_over_400)
|
||||
|
||||
# Test just under boundary
|
||||
just_under_400 = "x" * 399
|
||||
assert _validate_description_length(just_under_400) == just_under_400
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for description validation"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
|
||||
# Test None input
|
||||
assert _validate_description_length(None) is None
|
||||
|
||||
# Test empty string
|
||||
assert _validate_description_length("") == ""
|
||||
|
||||
# Test single character
|
||||
assert _validate_description_length("a") == "a"
|
||||
|
||||
# Test unicode characters
|
||||
unicode_desc = "测试" * 200 # 400 characters in Chinese
|
||||
assert _validate_description_length(unicode_desc) == unicode_desc
|
||||
|
||||
# Test unicode over limit
|
||||
unicode_over = "测试" * 201 # 402 characters
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(unicode_over)
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test how validation handles whitespace"""
|
||||
from controllers.console.app.app import _validate_description_length
|
||||
|
||||
# Test description with spaces
|
||||
spaces_400 = " " * 400
|
||||
assert _validate_description_length(spaces_400) == spaces_400
|
||||
|
||||
# Test description with spaces over limit
|
||||
spaces_401 = " " * 401
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(spaces_401)
|
||||
|
||||
# Test mixed content
|
||||
mixed_400 = "a" * 200 + " " * 200
|
||||
assert _validate_description_length(mixed_400) == mixed_400
|
||||
|
||||
# Test mixed over limit
|
||||
mixed_401 = "a" * 200 + " " * 201
|
||||
with pytest.raises(ValueError):
|
||||
_validate_description_length(mixed_401)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
import traceback
|
||||
|
||||
test_instance = TestAppDescriptionValidationUnit()
|
||||
test_methods = [method for method in dir(test_instance) if method.startswith("test_")]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test_method in test_methods:
|
||||
try:
|
||||
print(f"Running {test_method}...")
|
||||
getattr(test_instance, test_method)()
|
||||
print(f"✅ {test_method} PASSED")
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"❌ {test_method} FAILED: {str(e)}")
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print(f"\n📊 Test Results: {passed} passed, {failed} failed")
|
||||
|
||||
if failed == 0:
|
||||
print("🎉 All tests passed!")
|
||||
else:
|
||||
print("💥 Some tests failed!")
|
||||
sys.exit(1)
|
||||
|
|
@@ -0,0 +1,168 @@
|
|||
"""Integration tests for ClickZetta Volume Storage."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
|
||||
class TestClickZettaVolumeStorage(unittest.TestCase):
|
||||
"""Test cases for ClickZetta Volume Storage."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.config = ClickZettaVolumeConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
volume_type="table",
|
||||
table_prefix="test_dataset_",
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_user_volume_operations(self):
|
||||
"""Test basic operations with User Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "user"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations
|
||||
test_filename = "test_file.txt"
|
||||
test_content = b"Hello, ClickZetta Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test streaming
|
||||
stream_content = b""
|
||||
for chunk in storage.load_stream(test_filename):
|
||||
stream_content += chunk
|
||||
assert stream_content == test_content
|
||||
|
||||
# Test download
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
storage.download(test_filename, temp_file.name)
|
||||
with open(temp_file.name, "rb") as f:
|
||||
downloaded_content = f.read()
|
||||
assert downloaded_content == test_content
|
||||
|
||||
# Test scan
|
||||
files = storage.scan("", files=True, directories=False)
|
||||
assert test_filename in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_table_volume_operations(self):
|
||||
"""Test basic operations with Table Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "table"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations with dataset_id
|
||||
dataset_id = "12345"
|
||||
test_filename = f"{dataset_id}/test_file.txt"
|
||||
test_content = b"Hello, Table Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test scan for dataset
|
||||
files = storage.scan(dataset_id, files=True, directories=False)
|
||||
assert "test_file.txt" in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test configuration validation."""
|
||||
# Test missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="", # Empty username should fail
|
||||
password="pass",
|
||||
instance="instance",
|
||||
)
|
||||
|
||||
# Test invalid volume type
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
|
||||
|
||||
# Test external volume without volume_name
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="user",
|
||||
password="pass",
|
||||
instance="instance",
|
||||
volume_type="external",
|
||||
# Missing volume_name
|
||||
)
|
||||
|
||||
def test_volume_path_generation(self):
|
||||
"""Test volume path generation for different types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume path
|
||||
path = storage._get_volume_path("test.txt", "12345")
|
||||
assert path == "test_dataset_12345/test.txt"
|
||||
|
||||
# Test path with existing dataset_id prefix
|
||||
path = storage._get_volume_path("12345/test.txt")
|
||||
assert path == "12345/test.txt"
|
||||
|
||||
# Test user volume
|
||||
storage._config.volume_type = "user"
|
||||
path = storage._get_volume_path("test.txt")
|
||||
assert path == "test.txt"
|
||||
|
||||
def test_sql_prefix_generation(self):
|
||||
"""Test SQL prefix generation for different volume types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume SQL prefix
|
||||
prefix = storage._get_volume_sql_prefix("12345")
|
||||
assert prefix == "TABLE VOLUME test_dataset_12345"
|
||||
|
||||
# Test user volume SQL prefix
|
||||
storage._config.volume_type = "user"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "USER VOLUME"
|
||||
|
||||
# Test external volume SQL prefix
|
||||
storage._config.volume_type = "external"
|
||||
storage._config.volume_name = "my_external_volume"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "VOLUME my_external_volume"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@@ -0,0 +1,25 @@
# Clickzetta Integration Tests

## Running Tests

To run the Clickzetta integration tests, you need to set the following environment variables:

```bash
export CLICKZETTA_USERNAME=your_username
export CLICKZETTA_PASSWORD=your_password
export CLICKZETTA_INSTANCE=your_instance
export CLICKZETTA_SERVICE=api.clickzetta.com
export CLICKZETTA_WORKSPACE=your_workspace
export CLICKZETTA_VCLUSTER=your_vcluster
export CLICKZETTA_SCHEMA=dify
```

Then run the tests:

```bash
pytest api/tests/integration_tests/vdb/clickzetta/
```

## Security Note

Never commit credentials to the repository. Always use environment variables or secure credential management systems.
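A possible follow-up to that README, using only standard pytest flags (the test path comes from this PR; the `-k` expression just matches one of the test names shown below): when credentials are missing, the cases skip, and `-rs` reports why.

```bash
# Run only the full-text search cases and show skip reasons if credentials are absent
pytest api/tests/integration_tests/vdb/clickzetta/ -k "full_text" -rs
```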
@@ -0,0 +1,224 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class TestClickzettaVector(AbstractVectorTest):
|
||||
"""
|
||||
Test cases for Clickzetta vector database integration.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self):
|
||||
"""Create a Clickzetta vector store instance for testing."""
|
||||
# Skip test if Clickzetta credentials are not configured
|
||||
if not os.getenv("CLICKZETTA_USERNAME"):
|
||||
pytest.skip("CLICKZETTA_USERNAME is not configured")
|
||||
if not os.getenv("CLICKZETTA_PASSWORD"):
|
||||
pytest.skip("CLICKZETTA_PASSWORD is not configured")
|
||||
if not os.getenv("CLICKZETTA_INSTANCE"):
|
||||
pytest.skip("CLICKZETTA_INSTANCE is not configured")
|
||||
|
||||
config = ClickzettaConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", ""),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", ""),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", ""),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
|
||||
batch_size=10, # Small batch size for testing
|
||||
enable_inverted_index=True,
|
||||
analyzer_type="chinese",
|
||||
analyzer_mode="smart",
|
||||
vector_distance_function="cosine_distance",
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
|
||||
|
||||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
try:
|
||||
vector.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
# Prepare test data
|
||||
texts = [
|
||||
"这是第一个测试文档,包含一些中文内容。",
|
||||
"This is the second test document with English content.",
|
||||
"第三个文档混合了English和中文内容。",
|
||||
]
|
||||
embeddings = [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 1.0, 1.1, 1.2],
|
||||
]
|
||||
documents = [
|
||||
Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
|
||||
for i, text in enumerate(texts)
|
||||
]
|
||||
|
||||
# Test create (initial insert)
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test text_exists
|
||||
assert vector_store.text_exists("doc_0")
|
||||
assert not vector_store.text_exists("doc_999")
|
||||
|
||||
# Test search_by_vector
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=2)
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == texts[0] # Should match the first document
|
||||
|
||||
# Test search_by_full_text (Chinese)
|
||||
results = vector_store.search_by_full_text("中文", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with Chinese content
|
||||
|
||||
# Test search_by_full_text (English)
|
||||
results = vector_store.search_by_full_text("English", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with English content
|
||||
|
||||
# Test delete_by_ids
|
||||
vector_store.delete_by_ids(["doc_0"])
|
||||
assert not vector_store.text_exists("doc_0")
|
||||
assert vector_store.text_exists("doc_1")
|
||||
|
||||
# Test delete_by_metadata_field
|
||||
vector_store.delete_by_metadata_field("source", "test")
|
||||
assert not vector_store.text_exists("doc_1")
|
||||
assert not vector_store.text_exists("doc_2")
|
||||
|
||||
def test_clickzetta_vector_advanced_search(self, vector_store):
|
||||
"""Test advanced search features of Clickzetta vector store."""
|
||||
# Prepare test data with more complex metadata
|
||||
documents = []
|
||||
embeddings = []
|
||||
for i in range(10):
|
||||
doc = Document(
|
||||
page_content=f"Document {i}: " + get_example_text(),
|
||||
metadata={
|
||||
"doc_id": f"adv_doc_{i}",
|
||||
"category": "technical" if i % 2 == 0 else "general",
|
||||
"document_id": f"doc_{i // 3}", # Group documents
|
||||
"importance": i,
|
||||
},
|
||||
)
|
||||
documents.append(doc)
|
||||
# Create varied embeddings
|
||||
embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test vector search with document filter
|
||||
query_vector = [0.5, 1.0, 1.5, 2.0]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
|
||||
assert len(results) > 0
|
||||
# All results should belong to doc_0 or doc_1 groups
|
||||
for result in results:
|
||||
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
|
||||
|
||||
# Test score threshold
|
||||
results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
|
||||
# Check that all results have a score above threshold
|
||||
for result in results:
|
||||
assert result.metadata.get("score", 0) >= 0.5
|
||||
|
||||
def test_clickzetta_batch_operations(self, vector_store):
|
||||
"""Test batch insertion operations."""
|
||||
# Prepare large batch of documents
|
||||
batch_size = 25
|
||||
documents = []
|
||||
embeddings = []
|
||||
|
||||
for i in range(batch_size):
|
||||
doc = Document(
|
||||
page_content=f"Batch document {i}: This is a test document for batch processing.",
|
||||
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
|
||||
)
|
||||
documents.append(doc)
|
||||
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
|
||||
|
||||
# Test batch insert
|
||||
vector_store.add_texts(documents=documents, embeddings=embeddings)
|
||||
|
||||
# Verify all documents were inserted
|
||||
for i in range(batch_size):
|
||||
assert vector_store.text_exists(f"batch_doc_{i}")
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("batch", "test_batch")
|
||||
|
||||
def test_clickzetta_edge_cases(self, vector_store):
|
||||
"""Test edge cases and error handling."""
|
||||
# Test empty operations
|
||||
vector_store.create(texts=[], embeddings=[])
|
||||
vector_store.add_texts(documents=[], embeddings=[])
|
||||
vector_store.delete_by_ids([])
|
||||
|
||||
# Test special characters in content
|
||||
special_doc = Document(
|
||||
page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
|
||||
metadata={"doc_id": "special_doc", "test": "edge_case"},
|
||||
)
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4]]
|
||||
|
||||
vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
|
||||
assert vector_store.text_exists("special_doc")
|
||||
|
||||
# Test search with special characters
|
||||
results = vector_store.search_by_full_text("quotes", top_k=1)
|
||||
if results: # Full-text search might not be available
|
||||
assert len(results) > 0
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_ids(["special_doc"])
|
||||
|
||||
def test_clickzetta_full_text_search_modes(self, vector_store):
|
||||
"""Test different full-text search capabilities."""
|
||||
# Prepare documents with various language content
|
||||
documents = [
|
||||
Document(
|
||||
page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Clickzetta provides powerful Lakehouse solutions",
|
||||
metadata={"doc_id": "en_doc_1", "lang": "english"},
|
||||
),
|
||||
Document(
|
||||
page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Modern data architecture includes Lakehouse technology",
|
||||
metadata={"doc_id": "en_doc_2", "lang": "english"},
|
||||
),
|
||||
]
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test Chinese full-text search
|
||||
results = vector_store.search_by_full_text("Lakehouse", top_k=4)
|
||||
assert len(results) >= 2 # Should find at least documents with "Lakehouse"
|
||||
|
||||
# Test English full-text search
|
||||
results = vector_store.search_by_full_text("solutions", top_k=2)
|
||||
assert len(results) >= 1 # Should find English documents with "solutions"
|
||||
|
||||
# Test mixed search
|
||||
results = vector_store.search_by_full_text("数据架构", top_k=2)
|
||||
assert len(results) >= 1 # Should find Chinese documents with this phrase
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("lang", "chinese")
|
||||
vector_store.delete_by_metadata_field("lang", "english")
|
||||
|
|
@@ -0,0 +1,165 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Clickzetta integration in Docker environment
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
from clickzetta import connect
|
||||
|
||||
|
||||
def test_clickzetta_connection():
|
||||
"""Test direct connection to Clickzetta"""
|
||||
print("=== Testing direct Clickzetta connection ===")
|
||||
try:
|
||||
conn = connect(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
|
||||
database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
)
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Test basic connectivity
|
||||
cursor.execute("SELECT 1 as test")
|
||||
result = cursor.fetchone()
|
||||
print(f"✓ Connection test: {result}")
|
||||
|
||||
# Check if our test table exists
|
||||
cursor.execute("SHOW TABLES IN dify")
|
||||
tables = cursor.fetchall()
|
||||
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
|
||||
|
||||
# Check if test collection exists
|
||||
test_collection = "collection_test_dataset"
|
||||
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
|
||||
cursor.execute(f"DESCRIBE dify.{test_collection}")
|
||||
columns = cursor.fetchall()
|
||||
print(f"✓ Table structure for {test_collection}:")
|
||||
for col in columns:
|
||||
print(f" - {col[0]}: {col[1]}")
|
||||
|
||||
# Check for indexes
|
||||
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
|
||||
indexes = cursor.fetchall()
|
||||
print(f"✓ Indexes on {test_collection}:")
|
||||
for idx in indexes:
|
||||
print(f" - {idx}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_dify_api():
|
||||
"""Test Dify API with Clickzetta backend"""
|
||||
print("\n=== Testing Dify API ===")
|
||||
base_url = "http://localhost:5001"
|
||||
|
||||
# Wait for API to be ready
|
||||
max_retries = 30
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = requests.get(f"{base_url}/console/api/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ Dify API is ready")
|
||||
break
|
||||
except:
|
||||
if i == max_retries - 1:
|
||||
print("✗ Dify API is not responding")
|
||||
return False
|
||||
time.sleep(2)
|
||||
|
||||
# Check vector store configuration
|
||||
try:
|
||||
# This is a simplified check - in production, you'd use proper auth
|
||||
print("✓ Dify is configured to use Clickzetta as vector store")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ API test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def verify_table_structure():
|
||||
"""Verify the table structure meets Dify requirements"""
|
||||
print("\n=== Verifying Table Structure ===")
|
||||
|
||||
expected_columns = {
|
||||
"id": "VARCHAR",
|
||||
"page_content": "VARCHAR",
|
||||
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
|
||||
"vector": "ARRAY<FLOAT>",
|
||||
}
|
||||
|
||||
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
|
||||
|
||||
print("✓ Expected table structure:")
|
||||
for col, dtype in expected_columns.items():
|
||||
print(f" - {col}: {dtype}")
|
||||
|
||||
print("\n✓ Required metadata fields:")
|
||||
for field in expected_metadata_fields:
|
||||
print(f" - {field}")
|
||||
|
||||
print("\n✓ Index requirements:")
|
||||
print(" - Vector index (HNSW) on 'vector' column")
|
||||
print(" - Full-text index on 'page_content' (optional)")
|
||||
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
|
||||
print(" - Functional index on metadata->>'$.document_id' (recommended)")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Starting Clickzetta integration tests for Dify Docker\n")
|
||||
|
||||
tests = [
|
||||
("Direct Clickzetta Connection", test_clickzetta_connection),
|
||||
("Dify API Status", test_dify_api),
|
||||
("Table Structure Verification", verify_table_structure),
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((test_name, success))
|
||||
except Exception as e:
|
||||
print(f"\n✗ {test_name} crashed: {e}")
|
||||
results.append((test_name, False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Test Summary:")
|
||||
print("=" * 50)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
total = len(results)
|
||||
|
||||
for test_name, success in results:
|
||||
status = "✅ PASSED" if success else "❌ FAILED"
|
||||
print(f"{test_name}: {status}")
|
||||
|
||||
print(f"\nTotal: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
|
||||
print("\nNext steps:")
|
||||
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
|
||||
print("2. Access Dify at http://localhost:3000")
|
||||
print("3. Create a dataset and test vector storage with Clickzetta")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ Some tests failed. Please check the errors above.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,487 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
|
||||
class TestAPIBasedExtensionService:
|
||||
"""Integration tests for APIBasedExtensionService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
||||
|
||||
# Mock successful ping response
|
||||
mock_requestor_instance = mock_requestor.return_value
|
||||
mock_requestor_instance.request.return_value = {"result": "pong"}
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"requestor": mock_requestor,
|
||||
"requestor_instance": mock_requestor_instance,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful saving of API-based extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
# Save extension
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Verify extension was saved correctly
|
||||
assert saved_extension.id is not None
|
||||
assert saved_extension.tenant_id == tenant.id
|
||||
assert saved_extension.name == extension_data.name
|
||||
assert saved_extension.api_endpoint == extension_data.api_endpoint
|
||||
assert saved_extension.api_key == extension_data.api_key # Should be decrypted when retrieved
|
||||
assert saved_extension.created_at is not None
|
||||
|
||||
# Verify extension was saved to database
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(saved_extension)
|
||||
assert saved_extension.id is not None
|
||||
|
||||
# Verify ping connection was called
|
||||
mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
|
||||
|
||||
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation errors when saving extension with invalid data.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Test empty name
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = ""
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Test empty api_endpoint
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = ""
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Test empty api_key
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = ""
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of all extensions by tenant ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create multiple extensions
|
||||
extensions = []
|
||||
for i in range(3):
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = f"Extension {i}: {fake.company()}"
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
extensions.append(saved_extension)
|
||||
|
||||
# Get all extensions for tenant
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
|
||||
|
||||
# Verify results
|
||||
assert len(extension_list) == 3
|
||||
|
||||
# Verify all extensions belong to the correct tenant and are ordered by created_at desc
|
||||
for i, extension in enumerate(extension_list):
|
||||
assert extension.tenant_id == tenant.id
|
||||
assert extension.api_key is not None # Should be decrypted
|
||||
if i > 0:
|
||||
# Verify descending order (newer first)
|
||||
assert extension.created_at <= extension_list[i - 1].created_at
|
||||
|
||||
def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of extension by tenant ID and extension ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create an extension
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Get extension by ID
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
|
||||
|
||||
# Verify extension was retrieved correctly
|
||||
assert retrieved_extension is not None
|
||||
assert retrieved_extension.id == created_extension.id
|
||||
assert retrieved_extension.tenant_id == tenant.id
|
||||
assert retrieved_extension.name == extension_data.name
|
||||
assert retrieved_extension.api_endpoint == extension_data.api_endpoint
|
||||
assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted
|
||||
assert retrieved_extension.created_at is not None
|
||||
|
||||
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test retrieval of extension when extension is not found.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
non_existent_extension_id = fake.uuid4()
|
||||
|
||||
# Try to get non-existent extension
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
|
||||
|
||||
def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create an extension first
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
extension_id = created_extension.id
|
||||
|
||||
# Delete the extension
|
||||
APIBasedExtensionService.delete(created_extension)
|
||||
|
||||
# Verify extension was deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first()
|
||||
assert deleted_extension is None
|
||||
|
||||
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation error when saving extension with duplicate name.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create first extension
|
||||
extension_data1 = APIBasedExtension()
|
||||
extension_data1.tenant_id = tenant.id
|
||||
extension_data1.name = "Test Extension"
|
||||
extension_data1.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data1.api_key = fake.password(length=20)
|
||||
|
||||
APIBasedExtensionService.save(extension_data1)
|
||||
|
||||
# Try to create second extension with same name
|
||||
extension_data2 = APIBasedExtension()
|
||||
extension_data2.tenant_id = tenant.id
|
||||
extension_data2.name = "Test Extension" # Same name
|
||||
extension_data2.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data2.api_key = fake.password(length=20)
|
||||
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(extension_data2)
|
||||
|
||||
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful update of existing extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create initial extension
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = fake.password(length=20)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Save original values for later comparison
|
||||
original_name = created_extension.name
|
||||
original_endpoint = created_extension.api_endpoint
|
||||
|
||||
# Update the extension
|
||||
new_name = fake.company()
|
||||
new_endpoint = f"https://{fake.domain_name()}/api"
|
||||
new_api_key = fake.password(length=20)
|
||||
|
||||
created_extension.name = new_name
|
||||
created_extension.api_endpoint = new_endpoint
|
||||
created_extension.api_key = new_api_key
|
||||
|
||||
updated_extension = APIBasedExtensionService.save(created_extension)
|
||||
|
||||
# Verify extension was updated correctly
|
||||
assert updated_extension.id == created_extension.id
|
||||
assert updated_extension.tenant_id == tenant.id
|
||||
assert updated_extension.name == new_name
|
||||
assert updated_extension.api_endpoint == new_endpoint
|
||||
|
||||
# Verify original values were changed
|
||||
assert updated_extension.name != original_name
|
||||
assert updated_extension.api_endpoint != original_endpoint
|
||||
|
||||
# Verify ping connection was called for both create and update
|
||||
assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2
|
||||
|
||||
# Verify the update by retrieving the extension again
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
|
||||
assert retrieved_extension.name == new_name
|
||||
assert retrieved_extension.api_endpoint == new_endpoint
|
||||
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
|
||||
|
||||
def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test connection error when saving extension with invalid endpoint.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Mock connection error
|
||||
mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
|
||||
"connection error: request timeout"
|
||||
)
|
||||
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension()
|
||||
extension_data.tenant_id = tenant.id
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = "https://invalid-endpoint.com/api"
|
||||
        extension_data.api_key = fake.password(length=20)

        # Try to save extension with connection error
        with pytest.raises(ValueError, match="connection error: request timeout"):
            APIBasedExtensionService.save(extension_data)

    def test_save_extension_invalid_api_key_length(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test validation error when saving extension with API key that is too short.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Setup extension data with short API key
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = "1234"  # Less than 5 characters

        # Try to save extension with short API key
        with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
            APIBasedExtensionService.save(extension_data)

    def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation errors when saving extension with empty required fields.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Test with None values
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = None
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService.save(extension_data)

        # Test with None api_endpoint
        extension_data.name = fake.company()
        extension_data.api_endpoint = None

        with pytest.raises(ValueError, match="api_endpoint must not be empty"):
            APIBasedExtensionService.save(extension_data)

        # Test with None api_key
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = None

        with pytest.raises(ValueError, match="api_key must not be empty"):
            APIBasedExtensionService.save(extension_data)

    def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test retrieval of extensions when no extensions exist for tenant.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Get all extensions for tenant (none exist)
        extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)

        # Verify empty list is returned
        assert len(extension_list) == 0
        assert extension_list == []

    def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation error when ping response is invalid.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Mock invalid ping response
        mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}

        # Setup extension data
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        # Try to save extension with invalid ping response
        with pytest.raises(ValueError, match="{'result': 'invalid'}"):
            APIBasedExtensionService.save(extension_data)

    def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation error when ping response is missing result field.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Mock ping response without result field
        mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}

        # Setup extension data
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        # Try to save extension with missing ping result
        with pytest.raises(ValueError, match="{'status': 'ok'}"):
            APIBasedExtensionService.save(extension_data)

    def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test retrieval of extension when tenant ID doesn't match.
        """
        fake = Faker()
        account1, tenant1 = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create second account and tenant
        account2, tenant2 = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )

        # Create extension in first tenant
        extension_data = APIBasedExtension()
        extension_data.tenant_id = tenant1.id
        extension_data.name = fake.company()
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = fake.password(length=20)

        created_extension = APIBasedExtensionService.save(extension_data)

        # Try to get extension with wrong tenant ID
        with pytest.raises(ValueError, match="API based extension is not found"):
            APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)

@@ -0,0 +1,473 @@
import json
from unittest.mock import MagicMock, patch

import pytest
import yaml
from faker import Faker

from models.model import App, AppModelConfig
from services.account_service import AccountService, TenantService
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
from services.app_service import AppService


class TestAppDslService:
    """Integration tests for AppDslService using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("services.app_dsl_service.WorkflowService") as mock_workflow_service,
            patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service,
            patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service,
            patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy,
            patch("services.app_dsl_service.redis_client") as mock_redis_client,
            patch("services.app_dsl_service.app_was_created") as mock_app_was_created,
            patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
        ):
            # Setup default mock returns
            mock_workflow_service.return_value.get_draft_workflow.return_value = None
            mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock()
            mock_dependencies_service.generate_latest_dependencies.return_value = []
            mock_dependencies_service.get_leaked_dependencies.return_value = []
            mock_dependencies_service.generate_dependencies.return_value = []
            mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None
            mock_ssrf_proxy.get.return_value.content = b"test content"
            mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None
            mock_redis_client.setex.return_value = None
            mock_redis_client.get.return_value = None
            mock_redis_client.delete.return_value = None
            mock_app_was_created.send.return_value = None
            mock_app_model_config_was_updated.send.return_value = None

            # Mock ModelManager for app service
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")

            # Mock FeatureService and EnterpriseService
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None

            yield {
                "workflow_service": mock_workflow_service,
                "dependencies_service": mock_dependencies_service,
                "draft_variable_service": mock_draft_variable_service,
                "ssrf_proxy": mock_ssrf_proxy,
                "redis_client": mock_redis_client,
                "app_was_created": mock_app_was_created,
                "app_model_config_was_updated": mock_app_model_config_was_updated,
                "model_manager": mock_model_manager,
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
            }

    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (app, account) - Created app and account instances
        """
        fake = Faker()

        # Setup mocks for account creation
        with patch("services.account_service.FeatureService") as mock_account_feature_service:
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True

            # Create account and tenant first
            account = AccountService.create_account(
                email=fake.email(),
                name=fake.name(),
                interface_language="en-US",
                password=fake.password(length=12),
            )
            TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
            tenant = account.current_tenant

            # Setup app creation arguments
            app_args = {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            }

            # Create app
            app_service = AppService()
            app = app_service.create_app(tenant.id, app_args, account)

            return app, account

    def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"):
        """
        Helper method to create simple YAML content for testing.
        """
        yaml_data = {
            "version": "0.3.0",
            "kind": "app",
            "app": {
                "name": app_name,
                "mode": app_mode,
                "icon": "🤖",
                "icon_background": "#FFEAD5",
                "description": "Test app description",
                "use_icon_as_answer_icon": False,
            },
            "model_config": {
                "model": {
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0.7,
                        "top_p": 1.0,
                    },
                },
                "pre_prompt": "You are a helpful assistant.",
                "prompt_type": "simple",
            },
        }
        return yaml.dump(yaml_data, allow_unicode=True)

    def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app import from YAML content.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create YAML content
        yaml_content = self._create_simple_yaml_content(fake.company(), "chat")

        # Import app
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_CONTENT,
            yaml_content=yaml_content,
            name="Imported App",
            description="Imported app description",
        )

        # Verify import result
        assert result.status == ImportStatus.COMPLETED
        assert result.app_id is not None
        assert result.app_mode == "chat"
        assert result.imported_dsl_version == "0.3.0"
        assert result.error == ""

        # Verify app was created in database
        imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
        assert imported_app is not None
        assert imported_app.name == "Imported App"
        assert imported_app.description == "Imported app description"
        assert imported_app.mode == "chat"
        assert imported_app.tenant_id == account.current_tenant_id
        assert imported_app.created_by == account.id

        # Verify model config was created
        model_config = (
            db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first()
        )
        assert model_config is not None
        # The provider and model_id are stored in the model field as JSON
        model_dict = model_config.model_dict
        assert model_dict["provider"] == "openai"
        assert model_dict["name"] == "gpt-3.5-turbo"

    def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app import from YAML URL.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create YAML content for mock response
        yaml_content = self._create_simple_yaml_content(fake.company(), "chat")

        # Setup mock response
        mock_response = MagicMock()
        mock_response.content = yaml_content.encode("utf-8")
        mock_response.raise_for_status.return_value = None
        mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response

        # Import app from URL
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_URL,
            yaml_url="https://example.com/app.yaml",
            name="URL Imported App",
            description="App imported from URL",
        )

        # Verify import result
        assert result.status == ImportStatus.COMPLETED
        assert result.app_id is not None
        assert result.app_mode == "chat"
        assert result.imported_dsl_version == "0.3.0"
        assert result.error == ""

        # Verify app was created in database
        imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first()
        assert imported_app is not None
        assert imported_app.name == "URL Imported App"
        assert imported_app.description == "App imported from URL"
        assert imported_app.mode == "chat"
        assert imported_app.tenant_id == account.current_tenant_id

        # Verify ssrf_proxy was called
        mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with(
            "https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10)
        )

    def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app import with invalid YAML format.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create invalid YAML content
        invalid_yaml = "invalid: yaml: content: ["

        # Import app with invalid YAML
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_CONTENT,
            yaml_content=invalid_yaml,
            name="Invalid App",
        )

        # Verify import failed
        assert result.status == ImportStatus.FAILED
        assert result.app_id is None
        assert "Invalid YAML format" in result.error
        assert result.imported_dsl_version == ""

        # Verify no app was created in database
        apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
        assert apps_count == 1  # Only the original test app

    def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app import with missing YAML content.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Import app without YAML content
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_CONTENT,
            name="Missing Content App",
        )

        # Verify import failed
        assert result.status == ImportStatus.FAILED
        assert result.app_id is None
        assert "yaml_content is required" in result.error
        assert result.imported_dsl_version == ""

        # Verify no app was created in database
        apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
        assert apps_count == 1  # Only the original test app

    def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app import with missing YAML URL.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Import app without YAML URL
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_URL,
            name="Missing URL App",
        )

        # Verify import failed
        assert result.status == ImportStatus.FAILED
        assert result.app_id is None
        assert "yaml_url is required" in result.error
        assert result.imported_dsl_version == ""

        # Verify no app was created in database
        apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
        assert apps_count == 1  # Only the original test app

    def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app import with invalid import mode.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create YAML content
        yaml_content = self._create_simple_yaml_content(fake.company(), "chat")

        # Import app with invalid mode should raise ValueError
        dsl_service = AppDslService(db_session_with_containers)
        with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"):
            dsl_service.import_app(
                account=account,
                import_mode="invalid-mode",
                yaml_content=yaml_content,
                name="Invalid Mode App",
            )

        # Verify no app was created in database
        apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count()
        assert apps_count == 1  # Only the original test app

    def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful DSL export for chat app.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create model config for the app
        model_config = AppModelConfig()
        model_config.id = fake.uuid4()
        model_config.app_id = app.id
        model_config.provider = "openai"
        model_config.model_id = "gpt-3.5-turbo"
        model_config.model = json.dumps(
            {
                "provider": "openai",
                "name": "gpt-3.5-turbo",
                "mode": "chat",
                "completion_params": {
                    "max_tokens": 1000,
                    "temperature": 0.7,
                },
            }
        )
        model_config.pre_prompt = "You are a helpful assistant."
        model_config.prompt_type = "simple"
        model_config.created_by = account.id
        model_config.updated_by = account.id

        # Set the app_model_config_id to link the config
        app.app_model_config_id = model_config.id

        db_session_with_containers.add(model_config)
        db_session_with_containers.commit()

        # Export DSL
        exported_dsl = AppDslService.export_dsl(app, include_secret=False)

        # Parse exported YAML
        exported_data = yaml.safe_load(exported_dsl)

        # Verify exported data structure
        assert exported_data["kind"] == "app"
        assert exported_data["app"]["name"] == app.name
        assert exported_data["app"]["mode"] == app.mode
        assert exported_data["app"]["icon"] == app.icon
        assert exported_data["app"]["icon_background"] == app.icon_background
        assert exported_data["app"]["description"] == app.description

        # Verify model config was exported
        assert "model_config" in exported_data
        # The exported model_config structure may be different from the database structure
        # Check that the model config exists and has the expected content
        assert exported_data["model_config"] is not None

        # Verify dependencies were exported
        assert "dependencies" in exported_data
        assert isinstance(exported_data["dependencies"], list)

    def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful DSL export for workflow app.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Update app to workflow mode
        app.mode = "workflow"
        db_session_with_containers.commit()

        # Mock workflow service to return a workflow
        mock_workflow = MagicMock()
        mock_workflow.to_dict.return_value = {
            "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
            "features": {},
            "environment_variables": [],
            "conversation_variables": [],
        }
        mock_external_service_dependencies[
            "workflow_service"
        ].return_value.get_draft_workflow.return_value = mock_workflow

        # Export DSL
        exported_dsl = AppDslService.export_dsl(app, include_secret=False)

        # Parse exported YAML
        exported_data = yaml.safe_load(exported_dsl)

        # Verify exported data structure
        assert exported_data["kind"] == "app"
        assert exported_data["app"]["name"] == app.name
        assert exported_data["app"]["mode"] == "workflow"

        # Verify workflow was exported
        assert "workflow" in exported_data
        assert "graph" in exported_data["workflow"]
        assert "nodes" in exported_data["workflow"]["graph"]

        # Verify dependencies were exported
        assert "dependencies" in exported_data
        assert isinstance(exported_data["dependencies"], list)

        # Verify workflow service was called
        mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
            app
        )

    def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful dependency checking.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Mock Redis to return dependencies
        mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
        mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json

        # Check dependencies
        dsl_service = AppDslService(db_session_with_containers)
        result = dsl_service.check_dependencies(app_model=app)

        # Verify result
        assert result.leaked_dependencies == []

        # Verify Redis was queried
        mock_external_service_dependencies["redis_client"].get.assert_called_once_with(
            f"app_check_dependencies:{app.id}"
        )

        # Verify dependencies service was called
        mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once()

@@ -0,0 +1,928 @@
from unittest.mock import patch

import pytest
from faker import Faker

from constants.model_template import default_app_templates
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService


class TestAppService:
    """Integration tests for AppService using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.account_service.FeatureService") as mock_account_feature_service,
        ):
            # Setup default mock returns for app service
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None

            # Setup default mock returns for account service
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True

            # Mock ModelManager for model configuration
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")

            yield {
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
                "model_manager": mock_model_manager,
                "account_feature_service": mock_account_feature_service,
            }

    def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app creation with basic parameters.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Setup app creation arguments
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }

        # Create app
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Verify app was created correctly
        assert app.name == app_args["name"]
        assert app.description == app_args["description"]
        assert app.mode == app_args["mode"]
        assert app.icon_type == app_args["icon_type"]
        assert app.icon == app_args["icon"]
        assert app.icon_background == app_args["icon_background"]
        assert app.tenant_id == tenant.id
        assert app.api_rph == app_args["api_rph"]
        assert app.api_rpm == app_args["api_rpm"]
        assert app.created_by == account.id
        assert app.updated_by == account.id
        assert app.status == "normal"
        assert app.enable_site is True
        assert app.enable_api is True
        assert app.is_demo is False
        assert app.is_public is False
        assert app.is_universal is False

    def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app creation with different app modes.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        app_service = AppService()

        # Test different app modes
        # from AppMode enum in default_app_model_template
        app_modes = [v.value for v in default_app_templates]

        for mode in app_modes:
            app_args = {
                "name": f"{fake.company()} {mode}",
                "description": f"Test app for {mode} mode",
                "mode": mode,
                "icon_type": "emoji",
                "icon": "🚀",
                "icon_background": "#4ECDC4",
            }

            app = app_service.create_app(tenant.id, app_args, account)

            # Verify app mode was set correctly
            assert app.mode == mode
            assert app.name == app_args["name"]
            assert app.tenant_id == tenant.id
            assert app.created_by == account.id

    def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app retrieval.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        app_service = AppService()
        created_app = app_service.create_app(tenant.id, app_args, account)

        # Get app using the service
        retrieved_app = app_service.get_app(created_app)

        # Verify retrieved app matches created app
        assert retrieved_app.id == created_app.id
        assert retrieved_app.name == created_app.name
        assert retrieved_app.description == created_app.description
        assert retrieved_app.mode == created_app.mode
        assert retrieved_app.tenant_id == created_app.tenant_id
        assert retrieved_app.created_by == created_app.created_by

    def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful paginated app list retrieval.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        app_service = AppService()

        # Create multiple apps
        app_names = [fake.company() for _ in range(5)]
        for name in app_names:
            app_args = {
                "name": name,
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "📱",
                "icon_background": "#96CEB4",
            }
            app_service.create_app(tenant.id, app_args, account)

        # Get paginated apps
        args = {
            "page": 1,
            "limit": 10,
            "mode": "chat",
        }

        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)

        # Verify pagination results
        assert paginated_apps is not None
        assert len(paginated_apps.items) >= 5  # Should have at least 5 apps
        assert paginated_apps.page == 1
        assert paginated_apps.per_page == 10

        # Verify all apps belong to the correct tenant
        for app in paginated_apps.items:
            assert app.tenant_id == tenant.id
            assert app.mode == "chat"

    def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test paginated app list with various filters.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        app_service = AppService()

        # Create apps with different modes
        chat_app_args = {
            "name": "Chat App",
            "description": "A chat application",
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "💬",
            "icon_background": "#FF6B6B",
        }
        completion_app_args = {
            "name": "Completion App",
            "description": "A completion application",
            "mode": "completion",
            "icon_type": "emoji",
            "icon": "✍️",
            "icon_background": "#4ECDC4",
        }

        chat_app = app_service.create_app(tenant.id, chat_app_args, account)
        completion_app = app_service.create_app(tenant.id, completion_app_args, account)

        # Test filter by mode
        chat_args = {
            "page": 1,
            "limit": 10,
            "mode": "chat",
        }
        chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
        assert len(chat_apps.items) == 1
        assert chat_apps.items[0].mode == "chat"

        # Test filter by name
        name_args = {
            "page": 1,
            "limit": 10,
            "mode": "chat",
            "name": "Chat",
        }
        filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
        assert len(filtered_apps.items) == 1
        assert "Chat" in filtered_apps.items[0].name

        # Test filter by created_by_me
        created_by_me_args = {
            "page": 1,
            "limit": 10,
            "mode": "completion",
            "is_created_by_me": True,
        }
        my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
        assert len(my_apps.items) == 1

    def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test paginated app list with tag filters.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        app_service = AppService()

        # Create an app
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🏷️",
            "icon_background": "#FFEAA7",
        }
        app = app_service.create_app(tenant.id, app_args, account)

        # Mock TagService to return the app ID for tag filtering
        with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
            mock_tag_service.return_value = [app.id]

            # Test with tag filter
            args = {
                "page": 1,
                "limit": 10,
                "mode": "chat",
                "tag_ids": ["tag1", "tag2"],
            }

            paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)

            # Verify tag service was called
            mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])

            # Verify results
            assert paginated_apps is not None
            assert len(paginated_apps.items) == 1
            assert paginated_apps.items[0].id == app.id

        # Test with tag filter that returns no results
        with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
            mock_tag_service.return_value = []

            args = {
                "page": 1,
                "limit": 10,
                "mode": "chat",
                "tag_ids": ["nonexistent_tag"],
            }

            paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)

            # Should return None when no apps match tag filter
            assert paginated_apps is None

    def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app update with all fields.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store original values
        original_name = app.name
        original_description = app.description
        original_icon = app.icon
        original_icon_background = app.icon_background
        original_use_icon_as_answer_icon = app.use_icon_as_answer_icon

        # Update app
        update_args = {
            "name": "Updated App Name",
            "description": "Updated app description",
            "icon_type": "emoji",
            "icon": "🔄",
            "icon_background": "#FF8C42",
            "use_icon_as_answer_icon": True,
        }

        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app(app, update_args)

        # Verify updated fields
        assert updated_app.name == update_args["name"]
        assert updated_app.description == update_args["description"]
        assert updated_app.icon == update_args["icon"]
        assert updated_app.icon_background == update_args["icon_background"]
        assert updated_app.use_icon_as_answer_icon is True
        assert updated_app.updated_by == account.id

        # Verify other fields remain unchanged
        assert updated_app.mode == app.mode
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app name update.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store original name
        original_name = app.name

        # Update app name
        new_name = "New App Name"
        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app_name(app, new_name)

        assert updated_app.name == new_name
        assert updated_app.updated_by == account.id

        # Verify other fields remain unchanged
        assert updated_app.description == app.description
        assert updated_app.mode == app.mode
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app icon update.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store original values
        original_icon = app.icon
        original_icon_background = app.icon_background

        # Update app icon
        new_icon = "🌟"
        new_icon_background = "#FFD93D"
        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)

        assert updated_app.icon == new_icon
        assert updated_app.icon_background == new_icon_background
        assert updated_app.updated_by == account.id

        # Verify other fields remain unchanged
        assert updated_app.name == app.name
        assert updated_app.description == app.description
        assert updated_app.mode == app.mode
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app site status update.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🌐",
            "icon_background": "#74B9FF",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store original site status
        original_site_status = app.enable_site

        # Update site status to disabled
        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app_site_status(app, False)
            assert updated_app.enable_site is False
            assert updated_app.updated_by == account.id

        # Update site status back to enabled
        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app_site_status(updated_app, True)
            assert updated_app.enable_site is True
            assert updated_app.updated_by == account.id

        # Verify other fields remain unchanged
        assert updated_app.name == app.name
        assert updated_app.description == app.description
        assert updated_app.mode == app.mode
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app API status update.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔌",
            "icon_background": "#A29BFE",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store original API status
        original_api_status = app.enable_api

        # Update API status to disabled
        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app_api_status(app, False)
            assert updated_app.enable_api is False
            assert updated_app.updated_by == account.id

        # Update API status back to enabled
        with patch("flask_login.utils._get_user", return_value=account):
            updated_app = app_service.update_app_api_status(updated_app, True)
            assert updated_app.enable_api is True
            assert updated_app.updated_by == account.id

        # Verify other fields remain unchanged
        assert updated_app.name == app.name
        assert updated_app.description == app.description
        assert updated_app.mode == app.mode
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app site status update when status doesn't change.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔄",
            "icon_background": "#FD79A8",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store original values
        original_site_status = app.enable_site
        original_updated_at = app.updated_at

        # Update site status to the same value (no change)
        updated_app = app_service.update_app_site_status(app, original_site_status)

        # Verify app is returned unchanged
        assert updated_app.id == app.id
        assert updated_app.enable_site == original_site_status
        assert updated_app.updated_at == original_updated_at

        # Verify other fields remain unchanged
        assert updated_app.name == app.name
        assert updated_app.description == app.description
        assert updated_app.mode == app.mode
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app deletion.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🗑️",
            "icon_background": "#E17055",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store app ID for verification
        app_id = app.id

        # Mock the async deletion task
        with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
            mock_delete_task.delay.return_value = None

            # Delete app
            app_service.delete_app(app)

            # Verify async deletion task was called
            mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)

        # Verify app was deleted from database
        from extensions.ext_database import db

        deleted_app = db.session.query(App).filter_by(id=app_id).first()
        assert deleted_app is None

    def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app deletion with related data cleanup.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🧹",
            "icon_background": "#00B894",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Store app ID for verification
        app_id = app.id

        # Mock webapp auth cleanup
        mock_external_service_dependencies[
            "feature_service"
        ].get_system_features.return_value.webapp_auth.enabled = True

        # Mock the async deletion task
        with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
            mock_delete_task.delay.return_value = None

            # Delete app
            app_service.delete_app(app)

            # Verify webapp auth cleanup was called
            mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with(
                app_id
            )

            # Verify async deletion task was called
            mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)

        # Verify app was deleted from database
        from extensions.ext_database import db

        deleted_app = db.session.query(App).filter_by(id=app_id).first()
        assert deleted_app is None

    def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app metadata retrieval.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "📊",
            "icon_background": "#6C5CE7",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Get app metadata
        app_meta = app_service.get_app_meta(app)

        # Verify metadata contains expected fields
        assert "tool_icons" in app_meta
        # Note: get_app_meta currently only returns tool_icons

    def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app code retrieval by app ID.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔗",
            "icon_background": "#FDCB6E",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Get app code by ID
        app_code = AppService.get_app_code_by_id(app.id)

        # Verify app code was retrieved correctly
        # Note: Site would be created when App is created, site.code is auto-generated
        assert app_code is not None
        assert len(app_code) > 0

    def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful app ID retrieval by app code.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🆔",
            "icon_background": "#E84393",
        }

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

        # Create a site for the app
        site = Site()
        site.app_id = app.id
        site.code = fake.postalcode()
        site.title = fake.company()
        site.status = "normal"
        site.default_language = "en-US"
        site.customize_token_strategy = "uuid"
        from extensions.ext_database import db

        db.session.add(site)
        db.session.commit()

        # Get app ID by code
        app_id = AppService.get_app_id_by_code(site.code)

        # Verify app ID was retrieved correctly
        assert app_id == app.id

    def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app creation with invalid mode.
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Setup app creation arguments with invalid mode
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "invalid_mode",  # Invalid mode
            "icon_type": "emoji",
            "icon": "❌",
            "icon_background": "#D63031",
        }

        app_service = AppService()

        # Attempt to create app with invalid mode
        with pytest.raises(ValueError, match="invalid mode value"):
            app_service.create_app(tenant.id, app_args, account)

@@ -0,0 +1,252 @@
import pytest

from controllers.console.app.app import _validate_description_length as app_validate
from controllers.console.datasets.datasets import _validate_description_length as dataset_validate
from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate


class TestDescriptionValidationUnit:
    """Unit tests for description validation functions in App and Dataset APIs"""

    def test_app_validate_description_length_valid(self):
        """Test App validation function with valid descriptions"""
        # Empty string should be valid
        assert app_validate("") == ""

        # None should be valid
        assert app_validate(None) is None

        # Short description should be valid
        short_desc = "Short description"
        assert app_validate(short_desc) == short_desc

        # Exactly 400 characters should be valid
        exactly_400 = "x" * 400
        assert app_validate(exactly_400) == exactly_400

        # Just under limit should be valid
        just_under = "x" * 399
        assert app_validate(just_under) == just_under

    def test_app_validate_description_length_invalid(self):
        """Test App validation function with invalid descriptions"""
        # 401 characters should fail
        just_over = "x" * 401
        with pytest.raises(ValueError) as exc_info:
            app_validate(just_over)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

        # 500 characters should fail
        way_over = "x" * 500
        with pytest.raises(ValueError) as exc_info:
            app_validate(way_over)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

        # 1000 characters should fail
        very_long = "x" * 1000
        with pytest.raises(ValueError) as exc_info:
            app_validate(very_long)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

    def test_dataset_validate_description_length_valid(self):
        """Test Dataset validation function with valid descriptions"""
        # Empty string should be valid
        assert dataset_validate("") == ""

        # Short description should be valid
        short_desc = "Short description"
        assert dataset_validate(short_desc) == short_desc

        # Exactly 400 characters should be valid
        exactly_400 = "x" * 400
        assert dataset_validate(exactly_400) == exactly_400

        # Just under limit should be valid
        just_under = "x" * 399
        assert dataset_validate(just_under) == just_under

    def test_dataset_validate_description_length_invalid(self):
        """Test Dataset validation function with invalid descriptions"""
        # 401 characters should fail
        just_over = "x" * 401
        with pytest.raises(ValueError) as exc_info:
            dataset_validate(just_over)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

        # 500 characters should fail
        way_over = "x" * 500
        with pytest.raises(ValueError) as exc_info:
            dataset_validate(way_over)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

    def test_service_dataset_validate_description_length_valid(self):
        """Test Service Dataset validation function with valid descriptions"""
        # Empty string should be valid
        assert service_dataset_validate("") == ""

        # None should be valid
        assert service_dataset_validate(None) is None

        # Short description should be valid
        short_desc = "Short description"
        assert service_dataset_validate(short_desc) == short_desc

        # Exactly 400 characters should be valid
        exactly_400 = "x" * 400
        assert service_dataset_validate(exactly_400) == exactly_400

        # Just under limit should be valid
        just_under = "x" * 399
        assert service_dataset_validate(just_under) == just_under

    def test_service_dataset_validate_description_length_invalid(self):
        """Test Service Dataset validation function with invalid descriptions"""
        # 401 characters should fail
        just_over = "x" * 401
        with pytest.raises(ValueError) as exc_info:
            service_dataset_validate(just_over)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

        # 500 characters should fail
        way_over = "x" * 500
        with pytest.raises(ValueError) as exc_info:
            service_dataset_validate(way_over)
        assert "Description cannot exceed 400 characters." in str(exc_info.value)

    def test_app_dataset_validation_consistency(self):
        """Test that App and Dataset validation functions behave identically"""
        test_cases = [
            "",  # Empty string
            "Short description",  # Normal description
            "x" * 100,  # Medium description
            "x" * 400,  # Exactly at limit
        ]

        # Test valid cases produce same results
        for test_desc in test_cases:
            assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc)

        # Test invalid cases produce same errors
        invalid_cases = [
            "x" * 401,  # Just over limit
            "x" * 500,  # Way over limit
            "x" * 1000,  # Very long
        ]

        for invalid_desc in invalid_cases:
            app_error = None
            dataset_error = None
            service_dataset_error = None

            # Capture App validation error
            try:
                app_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
app_error = str(e)
|
||||
|
||||
# Capture Dataset validation error
|
||||
try:
|
||||
dataset_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
dataset_error = str(e)
|
||||
|
||||
# Capture Service Dataset validation error
|
||||
try:
|
||||
service_dataset_validate(invalid_desc)
|
||||
except ValueError as e:
|
||||
service_dataset_error = str(e)
|
||||
|
||||
# All should produce errors
|
||||
assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters"
|
||||
assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters"
|
||||
error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters"
|
||||
assert service_dataset_error is not None, error_msg
|
||||
|
||||
# Errors should be identical
|
||||
error_msg = f"Error messages should be identical for {len(invalid_desc)} characters"
|
||||
assert app_error == dataset_error == service_dataset_error, error_msg
|
||||
assert app_error == "Description cannot exceed 400 characters."
|
||||
|
||||
def test_boundary_values(self):
|
||||
"""Test boundary values around the 400 character limit"""
|
||||
boundary_tests = [
|
||||
(0, True), # Empty
|
||||
(1, True), # Minimum
|
||||
(399, True), # Just under limit
|
||||
(400, True), # Exactly at limit
|
||||
(401, False), # Just over limit
|
||||
(402, False), # Over limit
|
||||
(500, False), # Way over limit
|
||||
]
|
||||
|
||||
for length, should_pass in boundary_tests:
|
||||
test_desc = "x" * length
|
||||
|
||||
if should_pass:
|
||||
# Should not raise exception
|
||||
assert app_validate(test_desc) == test_desc
|
||||
assert dataset_validate(test_desc) == test_desc
|
||||
assert service_dataset_validate(test_desc) == test_desc
|
||||
else:
|
||||
# Should raise ValueError
|
||||
with pytest.raises(ValueError):
|
||||
app_validate(test_desc)
|
||||
with pytest.raises(ValueError):
|
||||
dataset_validate(test_desc)
|
||||
with pytest.raises(ValueError):
|
||||
service_dataset_validate(test_desc)
|
||||
|
||||
def test_special_characters(self):
|
||||
"""Test validation with special characters, Unicode, etc."""
|
||||
# Unicode characters
|
||||
unicode_desc = "测试描述" * 100 # Chinese characters
|
||||
if len(unicode_desc) <= 400:
|
||||
assert app_validate(unicode_desc) == unicode_desc
|
||||
assert dataset_validate(unicode_desc) == unicode_desc
|
||||
assert service_dataset_validate(unicode_desc) == unicode_desc
|
||||
|
||||
# Special characters
|
||||
special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10
|
||||
if len(special_desc) <= 400:
|
||||
assert app_validate(special_desc) == special_desc
|
||||
assert dataset_validate(special_desc) == special_desc
|
||||
assert service_dataset_validate(special_desc) == special_desc
|
||||
|
||||
# Mixed content
|
||||
mixed_desc = "Mixed content: 测试 123 !@# " * 15
|
||||
if len(mixed_desc) <= 400:
|
||||
assert app_validate(mixed_desc) == mixed_desc
|
||||
assert dataset_validate(mixed_desc) == mixed_desc
|
||||
assert service_dataset_validate(mixed_desc) == mixed_desc
|
||||
elif len(mixed_desc) > 400:
|
||||
with pytest.raises(ValueError):
|
||||
app_validate(mixed_desc)
|
||||
with pytest.raises(ValueError):
|
||||
dataset_validate(mixed_desc)
|
||||
with pytest.raises(ValueError):
|
||||
service_dataset_validate(mixed_desc)
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test validation with various whitespace scenarios"""
|
||||
# Leading/trailing whitespace
|
||||
whitespace_desc = " Description with whitespace "
|
||||
if len(whitespace_desc) <= 400:
|
||||
assert app_validate(whitespace_desc) == whitespace_desc
|
||||
assert dataset_validate(whitespace_desc) == whitespace_desc
|
||||
assert service_dataset_validate(whitespace_desc) == whitespace_desc
|
||||
|
||||
# Newlines and tabs
|
||||
multiline_desc = "Line 1\nLine 2\tTabbed content"
|
||||
if len(multiline_desc) <= 400:
|
||||
assert app_validate(multiline_desc) == multiline_desc
|
||||
assert dataset_validate(multiline_desc) == multiline_desc
|
||||
assert service_dataset_validate(multiline_desc) == multiline_desc
|
||||
|
||||
# Only whitespace over limit
|
||||
only_spaces = " " * 401
|
||||
with pytest.raises(ValueError):
|
||||
app_validate(only_spaces)
|
||||
with pytest.raises(ValueError):
|
||||
dataset_validate(only_spaces)
|
||||
with pytest.raises(ValueError):
|
||||
service_dataset_validate(only_spaces)
|
||||
|
|
@@ -0,0 +1,336 @@
|
|||
"""
|
||||
Unit tests for Service API File Preview endpoint
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError
|
||||
from controllers.service_api.app.file_preview import FilePreviewApi
|
||||
from models.model import App, EndUser, Message, MessageFile, UploadFile
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
"""Test suite for FilePreviewApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def file_preview_api(self):
|
||||
"""Create FilePreviewApi instance for testing"""
|
||||
return FilePreviewApi()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Mock App model"""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user(self):
|
||||
"""Mock EndUser model"""
|
||||
end_user = Mock(spec=EndUser)
|
||||
end_user.id = str(uuid.uuid4())
|
||||
return end_user
|
||||
|
||||
@pytest.fixture
|
||||
def mock_upload_file(self):
|
||||
"""Mock UploadFile model"""
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.id = str(uuid.uuid4())
|
||||
upload_file.name = "test_file.jpg"
|
||||
upload_file.mime_type = "image/jpeg"
|
||||
upload_file.size = 1024
|
||||
upload_file.key = "storage/key/test_file.jpg"
|
||||
upload_file.tenant_id = str(uuid.uuid4())
|
||||
return upload_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file(self):
|
||||
"""Mock MessageFile model"""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.upload_file_id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Mock Message model"""
|
||||
message = Mock(spec=Message)
|
||||
message.id = str(uuid.uuid4())
|
||||
message.app_id = str(uuid.uuid4())
|
||||
return message
|
||||
|
||||
def test_validate_file_ownership_success(
|
||||
self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
|
||||
):
|
||||
"""Test successful file ownership validation"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = mock_app.id
|
||||
|
||||
# Set up the mocks
|
||||
mock_upload_file.tenant_id = mock_app.tenant_id
|
||||
mock_message.app_id = app_id
|
||||
mock_message_file.upload_file_id = file_id
|
||||
mock_message_file.message_id = mock_message.id
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
|
||||
# Execute the method
|
||||
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
|
||||
# Assertions
|
||||
assert result_message_file == mock_message_file
|
||||
assert result_upload_file == mock_upload_file
|
||||
|
||||
def test_validate_file_ownership_file_not_found(self, file_preview_api):
|
||||
"""Test file ownership validation when MessageFile not found"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock MessageFile not found
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileNotFoundError) as exc_info:
|
||||
file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
|
||||
assert "File not found in message context" in str(exc_info.value)
|
||||
|
||||
def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file):
|
||||
"""Test file ownership validation when Message not owned by app"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock MessageFile found but Message not owned by app
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_message_file, # MessageFile query - found
|
||||
None, # Message query - not found (access denied)
|
||||
]
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileAccessDeniedError) as exc_info:
|
||||
file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
|
||||
assert "not owned by requesting app" in str(exc_info.value)
|
||||
|
||||
def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message):
|
||||
"""Test file ownership validation when UploadFile not found"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock MessageFile and Message found but UploadFile not found
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_message_file, # MessageFile query - found
|
||||
mock_message, # Message query - found
|
||||
None, # UploadFile query - not found
|
||||
]
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileNotFoundError) as exc_info:
|
||||
file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
|
||||
assert "Upload file record not found" in str(exc_info.value)
|
||||
|
||||
def test_validate_file_ownership_tenant_mismatch(
|
||||
self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
|
||||
):
|
||||
"""Test file ownership validation with tenant mismatch"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = mock_app.id
|
||||
|
||||
# Set up tenant mismatch
|
||||
mock_upload_file.tenant_id = "different_tenant_id"
|
||||
mock_app.tenant_id = "app_tenant_id"
|
||||
mock_message.app_id = app_id
|
||||
mock_message_file.upload_file_id = file_id
|
||||
mock_message_file.message_id = mock_message.id
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileAccessDeniedError) as exc_info:
|
||||
file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
|
||||
assert "tenant mismatch" in str(exc_info.value)
|
||||
|
||||
def test_validate_file_ownership_invalid_input(self, file_preview_api):
|
||||
"""Test file ownership validation with invalid input"""
|
||||
|
||||
# Test with empty file_id
|
||||
with pytest.raises(FileAccessDeniedError) as exc_info:
|
||||
file_preview_api._validate_file_ownership("", "app_id")
|
||||
assert "Invalid file or app identifier" in str(exc_info.value)
|
||||
|
||||
# Test with empty app_id
|
||||
with pytest.raises(FileAccessDeniedError) as exc_info:
|
||||
file_preview_api._validate_file_ownership("file_id", "")
|
||||
assert "Invalid file or app identifier" in str(exc_info.value)
|
||||
|
||||
def test_build_file_response_basic(self, file_preview_api, mock_upload_file):
|
||||
"""Test basic file response building"""
|
||||
mock_generator = Mock()
|
||||
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
|
||||
# Check response properties
|
||||
assert response.mimetype == mock_upload_file.mime_type
|
||||
assert response.direct_passthrough is True
|
||||
assert response.headers["Content-Length"] == str(mock_upload_file.size)
|
||||
assert "Cache-Control" in response.headers
|
||||
|
||||
def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file):
|
||||
"""Test file response building with attachment flag"""
|
||||
mock_generator = Mock()
|
||||
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True)
|
||||
|
||||
# Check attachment-specific headers
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
assert mock_upload_file.name in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
|
||||
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
|
||||
"""Test file response building for audio/video files"""
|
||||
mock_generator = Mock()
|
||||
mock_upload_file.mime_type = "video/mp4"
|
||||
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
|
||||
# Check Range support for media files
|
||||
assert response.headers["Accept-Ranges"] == "bytes"
|
||||
|
||||
def test_build_file_response_no_size(self, file_preview_api, mock_upload_file):
|
||||
"""Test file response building when size is unknown"""
|
||||
mock_generator = Mock()
|
||||
mock_upload_file.size = 0 # Unknown size
|
||||
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
|
||||
# Content-Length should not be set when size is unknown
|
||||
assert "Content-Length" not in response.headers
|
||||
|
||||
@patch("controllers.service_api.app.file_preview.storage")
|
||||
def test_get_method_integration(
|
||||
self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message
|
||||
):
|
||||
"""Test the full GET method integration (without decorator)"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = mock_app.id
|
||||
|
||||
# Set up mocks
|
||||
mock_upload_file.tenant_id = mock_app.tenant_id
|
||||
mock_message.app_id = app_id
|
||||
mock_message_file.upload_file_id = file_id
|
||||
mock_message_file.message_id = mock_message.id
|
||||
|
||||
mock_generator = Mock()
|
||||
mock_storage.load.return_value = mock_generator
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse:
|
||||
# Mock request parsing
|
||||
mock_parser = Mock()
|
||||
mock_parser.parse_args.return_value = {"as_attachment": False}
|
||||
mock_reqparse.RequestParser.return_value = mock_parser
|
||||
|
||||
# Test the core logic directly without Flask decorators
|
||||
# Validate file ownership
|
||||
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
assert result_message_file == mock_message_file
|
||||
assert result_upload_file == mock_upload_file
|
||||
|
||||
# Test file response building
|
||||
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||
assert response is not None
|
||||
|
||||
# Verify storage was called correctly
|
||||
mock_storage.load.assert_not_called() # Since we're testing components separately
|
||||
|
||||
@patch("controllers.service_api.app.file_preview.storage")
|
||||
def test_storage_error_handling(
|
||||
self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message
|
||||
):
|
||||
"""Test storage error handling in the core logic"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = mock_app.id
|
||||
|
||||
# Set up mocks
|
||||
mock_upload_file.tenant_id = mock_app.tenant_id
|
||||
mock_message.app_id = app_id
|
||||
mock_message_file.upload_file_id = file_id
|
||||
mock_message_file.message_id = mock_message.id
|
||||
|
||||
# Mock storage error
|
||||
mock_storage.load.side_effect = Exception("Storage error")
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries for validation
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
|
||||
# First validate file ownership works
|
||||
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
assert result_message_file == mock_message_file
|
||||
assert result_upload_file == mock_upload_file
|
||||
|
||||
# Test storage error handling
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
mock_storage.load(mock_upload_file.key, stream=True)
|
||||
|
||||
assert "Storage error" in str(exc_info.value)
|
||||
|
||||
@patch("controllers.service_api.app.file_preview.logger")
|
||||
def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api):
|
||||
"""Test that unexpected errors are logged properly"""
|
||||
file_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database query to raise unexpected exception
|
||||
mock_db.session.query.side_effect = Exception("Unexpected database error")
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileAccessDeniedError) as exc_info:
|
||||
file_preview_api._validate_file_ownership(file_id, app_id)
|
||||
|
||||
# Verify error message
|
||||
assert "File access validation failed" in str(exc_info.value)
|
||||
|
||||
# Verify logging was called
|
||||
mock_logger.exception.assert_called_once_with(
|
||||
"Unexpected error during file ownership validation",
|
||||
extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"},
|
||||
)
|
||||
|
|
@@ -0,0 +1,419 @@
|
|||
"""Test conversation variable handling in AdvancedChatAppRunner."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.variables import SegmentType
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable, Workflow
|
||||
|
||||
|
||||
class TestAdvancedChatAppRunnerConversationVariables:
|
||||
"""Test that AdvancedChatAppRunner correctly handles conversation variables."""
|
||||
|
||||
def test_missing_conversation_variables_are_added(self):
|
||||
"""Test that new conversation variables added to workflow are created for existing conversations."""
|
||||
# Setup
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
|
||||
# Create workflow with two conversation variables
|
||||
workflow_vars = [
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": "var1",
|
||||
"name": "existing_var",
|
||||
"value_type": SegmentType.STRING,
|
||||
"value": "default1",
|
||||
}
|
||||
),
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": "var2",
|
||||
"name": "new_var",
|
||||
"value_type": SegmentType.STRING,
|
||||
"value": "default2",
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow with conversation variables
|
||||
mock_workflow = MagicMock(spec=Workflow)
|
||||
mock_workflow.conversation_variables = workflow_vars
|
||||
mock_workflow.tenant_id = str(uuid4())
|
||||
mock_workflow.app_id = app_id
|
||||
mock_workflow.id = workflow_id
|
||||
mock_workflow.type = "chat"
|
||||
mock_workflow.graph_dict = {}
|
||||
mock_workflow.environment_variables = []
|
||||
|
||||
# Create existing conversation variable (only var1 exists in DB)
|
||||
existing_db_var = MagicMock(spec=ConversationVariable)
|
||||
existing_db_var.id = "var1"
|
||||
existing_db_var.app_id = app_id
|
||||
existing_db_var.conversation_id = conversation_id
|
||||
existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0])
|
||||
|
||||
# Mock conversation and message
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.app_id = app_id
|
||||
mock_conversation.id = conversation_id
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = str(uuid4())
|
||||
|
||||
# Mock app config
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.app_id = app_id
|
||||
mock_app_config.workflow_id = workflow_id
|
||||
mock_app_config.tenant_id = str(uuid4())
|
||||
|
||||
# Mock app generate entity
|
||||
mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
|
||||
mock_app_generate_entity.app_config = mock_app_config
|
||||
mock_app_generate_entity.inputs = {}
|
||||
mock_app_generate_entity.query = "test query"
|
||||
mock_app_generate_entity.files = []
|
||||
mock_app_generate_entity.user_id = str(uuid4())
|
||||
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
mock_app_generate_entity.workflow_run_id = str(uuid4())
|
||||
mock_app_generate_entity.call_depth = 0
|
||||
mock_app_generate_entity.single_iteration_run = None
|
||||
mock_app_generate_entity.single_loop_run = None
|
||||
mock_app_generate_entity.trace_manager = None
|
||||
|
||||
# Create runner
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=mock_app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
dialogue_count=1,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_session = MagicMock(spec=Session)
|
||||
|
||||
# First query returns only existing variable
|
||||
mock_scalars_result = MagicMock()
|
||||
mock_scalars_result.all.return_value = [existing_db_var]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Track what gets added to session
|
||||
added_items = []
|
||||
|
||||
def track_add_all(items):
|
||||
added_items.extend(items)
|
||||
|
||||
mock_session.add_all.side_effect = track_add_all
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock graph initialization
|
||||
mock_init_graph.return_value = MagicMock()
|
||||
|
||||
# Mock workflow entry
|
||||
mock_workflow_entry = MagicMock()
|
||||
mock_workflow_entry.run.return_value = iter([]) # Empty generator
|
||||
mock_workflow_entry_class.return_value = mock_workflow_entry
|
||||
|
||||
# Run the method
|
||||
runner.run()
|
||||
|
||||
# Verify that the missing variable was added
|
||||
assert len(added_items) == 1, "Should have added exactly one missing variable"
|
||||
|
||||
# Check that the added item is the missing variable (var2)
|
||||
added_var = added_items[0]
|
||||
assert hasattr(added_var, "id"), "Added item should be a ConversationVariable"
|
||||
# Note: Since we're mocking ConversationVariable.from_variable,
|
||||
# we can't directly check the id, but we can verify add_all was called
|
||||
assert mock_session.add_all.called, "Session add_all should have been called"
|
||||
assert mock_session.commit.called, "Session commit should have been called"
|
||||
|
||||
def test_no_variables_creates_all(self):
|
||||
"""Test that all conversation variables are created when none exist in DB."""
|
||||
# Setup
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
|
||||
# Create workflow with conversation variables
|
||||
workflow_vars = [
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": "var1",
|
||||
"name": "var1",
|
||||
"value_type": SegmentType.STRING,
|
||||
"value": "default1",
|
||||
}
|
||||
),
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": "var2",
|
||||
"name": "var2",
|
||||
"value_type": SegmentType.STRING,
|
||||
"value": "default2",
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow
|
||||
mock_workflow = MagicMock(spec=Workflow)
|
||||
mock_workflow.conversation_variables = workflow_vars
|
||||
mock_workflow.tenant_id = str(uuid4())
|
||||
mock_workflow.app_id = app_id
|
||||
mock_workflow.id = workflow_id
|
||||
mock_workflow.type = "chat"
|
||||
mock_workflow.graph_dict = {}
|
||||
mock_workflow.environment_variables = []
|
||||
|
||||
# Mock conversation and message
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.app_id = app_id
|
||||
mock_conversation.id = conversation_id
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = str(uuid4())
|
||||
|
||||
# Mock app config
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.app_id = app_id
|
||||
mock_app_config.workflow_id = workflow_id
|
||||
mock_app_config.tenant_id = str(uuid4())
|
||||
|
||||
# Mock app generate entity
|
||||
mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
|
||||
mock_app_generate_entity.app_config = mock_app_config
|
||||
mock_app_generate_entity.inputs = {}
|
||||
mock_app_generate_entity.query = "test query"
|
||||
mock_app_generate_entity.files = []
|
||||
mock_app_generate_entity.user_id = str(uuid4())
|
||||
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
mock_app_generate_entity.workflow_run_id = str(uuid4())
|
||||
mock_app_generate_entity.call_depth = 0
|
||||
mock_app_generate_entity.single_iteration_run = None
|
||||
mock_app_generate_entity.single_loop_run = None
|
||||
mock_app_generate_entity.trace_manager = None
|
||||
|
||||
# Create runner
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=mock_app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
dialogue_count=1,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_session = MagicMock(spec=Session)
|
||||
|
||||
# Query returns empty list (no existing variables)
|
||||
mock_scalars_result = MagicMock()
|
||||
mock_scalars_result.all.return_value = []
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Track what gets added to session
|
||||
added_items = []
|
||||
|
||||
def track_add_all(items):
|
||||
added_items.extend(items)
|
||||
|
||||
mock_session.add_all.side_effect = track_add_all
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock ConversationVariable.from_variable to return mock objects
|
||||
mock_conv_vars = []
|
||||
for var in workflow_vars:
|
||||
mock_cv = MagicMock()
|
||||
mock_cv.id = var.id
|
||||
mock_cv.to_variable.return_value = var
|
||||
mock_conv_vars.append(mock_cv)
|
||||
|
||||
mock_conv_var_class.from_variable.side_effect = mock_conv_vars
|
||||
|
||||
# Mock graph initialization
|
||||
mock_init_graph.return_value = MagicMock()
|
||||
|
||||
# Mock workflow entry
|
||||
mock_workflow_entry = MagicMock()
|
||||
mock_workflow_entry.run.return_value = iter([]) # Empty generator
|
||||
mock_workflow_entry_class.return_value = mock_workflow_entry
|
||||
|
||||
# Run the method
|
||||
runner.run()
|
||||
|
||||
# Verify that all variables were created
|
||||
assert len(added_items) == 2, "Should have added both variables"
|
||||
assert mock_session.add_all.called, "Session add_all should have been called"
|
||||
assert mock_session.commit.called, "Session commit should have been called"
|
||||
|
||||
def test_all_variables_exist_no_changes(self):
|
||||
"""Test that no changes are made when all variables already exist in DB."""
|
||||
# Setup
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
|
||||
# Create workflow with conversation variables
|
||||
workflow_vars = [
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": "var1",
|
||||
"name": "var1",
|
||||
"value_type": SegmentType.STRING,
|
||||
"value": "default1",
|
||||
}
|
||||
),
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": "var2",
|
||||
"name": "var2",
|
||||
"value_type": SegmentType.STRING,
|
||||
"value": "default2",
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow
|
||||
mock_workflow = MagicMock(spec=Workflow)
|
||||
mock_workflow.conversation_variables = workflow_vars
|
||||
mock_workflow.tenant_id = str(uuid4())
|
||||
mock_workflow.app_id = app_id
|
||||
mock_workflow.id = workflow_id
|
||||
mock_workflow.type = "chat"
|
||||
mock_workflow.graph_dict = {}
|
||||
mock_workflow.environment_variables = []
|
||||
|
||||
# Create existing conversation variables (both exist in DB)
|
||||
existing_db_vars = []
|
||||
for var in workflow_vars:
|
||||
db_var = MagicMock(spec=ConversationVariable)
|
||||
db_var.id = var.id
|
||||
db_var.app_id = app_id
|
||||
db_var.conversation_id = conversation_id
|
||||
db_var.to_variable = MagicMock(return_value=var)
|
||||
existing_db_vars.append(db_var)
|
||||
|
||||
# Mock conversation and message
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.app_id = app_id
|
||||
mock_conversation.id = conversation_id
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = str(uuid4())
|
||||
|
||||
# Mock app config
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.app_id = app_id
|
||||
mock_app_config.workflow_id = workflow_id
|
||||
mock_app_config.tenant_id = str(uuid4())
|
||||
|
||||
# Mock app generate entity
|
||||
mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
|
||||
mock_app_generate_entity.app_config = mock_app_config
|
||||
mock_app_generate_entity.inputs = {}
|
||||
mock_app_generate_entity.query = "test query"
|
||||
mock_app_generate_entity.files = []
|
||||
mock_app_generate_entity.user_id = str(uuid4())
|
||||
mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
mock_app_generate_entity.workflow_run_id = str(uuid4())
|
||||
mock_app_generate_entity.call_depth = 0
|
||||
mock_app_generate_entity.single_iteration_run = None
|
||||
mock_app_generate_entity.single_loop_run = None
|
||||
mock_app_generate_entity.trace_manager = None
|
||||
|
||||
# Create runner
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=mock_app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
dialogue_count=1,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_session = MagicMock(spec=Session)
|
||||
|
||||
# Query returns all existing variables
|
||||
mock_scalars_result = MagicMock()
|
||||
mock_scalars_result.all.return_value = existing_db_vars
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock graph initialization
|
||||
mock_init_graph.return_value = MagicMock()
|
||||
|
||||
# Mock workflow entry
|
||||
mock_workflow_entry = MagicMock()
|
||||
mock_workflow_entry.run.return_value = iter([]) # Empty generator
|
||||
mock_workflow_entry_class.return_value = mock_workflow_entry
|
||||
|
||||
# Run the method
|
||||
runner.run()
|
||||
|
||||
# Verify that no variables were added
|
||||
assert not mock_session.add_all.called, "Session add_all should not have been called"
|
||||
assert mock_session.commit.called, "Session commit should still be called"
|
||||
|
|
@@ -49,7 +49,7 @@ def test_executor_with_json_body_and_number_variable():
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == []
    assert executor.params is None
    assert executor.json == {"number": 42}
    assert executor.data is None
    assert executor.files is None

@@ -102,7 +102,7 @@ def test_executor_with_json_body_and_object_variable():
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == []
    assert executor.params is None
    assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
    assert executor.data is None
    assert executor.files is None

@@ -157,7 +157,7 @@ def test_executor_with_json_body_and_nested_object_variable():
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/data"
    assert executor.headers == {"Content-Type": "application/json"}
    assert executor.params == []
    assert executor.params is None
    assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
    assert executor.data is None
    assert executor.files is None

@@ -245,7 +245,7 @@ def test_executor_with_form_data():
    assert executor.url == "https://api.example.com/upload"
    assert "Content-Type" in executor.headers
    assert "multipart/form-data" in executor.headers["Content-Type"]
    assert executor.params == []
    assert executor.params is None
    assert executor.json is None
    # '__multipart_placeholder__' is expected when no file inputs exist,
    # to ensure the request is treated as multipart/form-data by the backend.
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,127 @@
import uuid
from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from services.conversation_service import ConversationService


class TestConversationService:
    def test_pagination_with_empty_include_ids(self):
        """Test that empty include_ids returns empty result"""
        mock_session = MagicMock()
        mock_app_model = MagicMock(id=str(uuid.uuid4()))
        mock_user = MagicMock(id=str(uuid.uuid4()))

        result = ConversationService.pagination_by_last_id(
            session=mock_session,
            app_model=mock_app_model,
            user=mock_user,
            last_id=None,
            limit=20,
            invoke_from=InvokeFrom.WEB_APP,
            include_ids=[],  # Empty include_ids should return empty result
            exclude_ids=None,
        )

        assert result.data == []
        assert result.has_more is False
        assert result.limit == 20

    def test_pagination_with_non_empty_include_ids(self):
        """Test that non-empty include_ids filters properly"""
        mock_session = MagicMock()
        mock_app_model = MagicMock(id=str(uuid.uuid4()))
        mock_user = MagicMock(id=str(uuid.uuid4()))

        # Mock the query results
        mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
        mock_session.scalars.return_value.all.return_value = mock_conversations
        mock_session.scalar.return_value = 0

        with patch("services.conversation_service.select") as mock_select:
            mock_stmt = MagicMock()
            mock_select.return_value = mock_stmt
            mock_stmt.where.return_value = mock_stmt
            mock_stmt.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_stmt.subquery.return_value = MagicMock()

            result = ConversationService.pagination_by_last_id(
                session=mock_session,
                app_model=mock_app_model,
                user=mock_user,
                last_id=None,
                limit=20,
                invoke_from=InvokeFrom.WEB_APP,
                include_ids=["conv1", "conv2"],  # Non-empty include_ids
                exclude_ids=None,
            )

            # Verify the where clause was called with id.in_
            assert mock_stmt.where.called

    def test_pagination_with_empty_exclude_ids(self):
        """Test that empty exclude_ids doesn't filter"""
        mock_session = MagicMock()
        mock_app_model = MagicMock(id=str(uuid.uuid4()))
        mock_user = MagicMock(id=str(uuid.uuid4()))

        # Mock the query results
        mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)]
        mock_session.scalars.return_value.all.return_value = mock_conversations
        mock_session.scalar.return_value = 0

        with patch("services.conversation_service.select") as mock_select:
            mock_stmt = MagicMock()
            mock_select.return_value = mock_stmt
            mock_stmt.where.return_value = mock_stmt
            mock_stmt.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_stmt.subquery.return_value = MagicMock()

            result = ConversationService.pagination_by_last_id(
                session=mock_session,
                app_model=mock_app_model,
                user=mock_user,
                last_id=None,
                limit=20,
                invoke_from=InvokeFrom.WEB_APP,
                include_ids=None,
                exclude_ids=[],  # Empty exclude_ids should not filter
            )

            # Result should contain the mocked conversations
            assert len(result.data) == 5

    def test_pagination_with_non_empty_exclude_ids(self):
        """Test that non-empty exclude_ids filters properly"""
        mock_session = MagicMock()
        mock_app_model = MagicMock(id=str(uuid.uuid4()))
        mock_user = MagicMock(id=str(uuid.uuid4()))

        # Mock the query results
        mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
        mock_session.scalars.return_value.all.return_value = mock_conversations
        mock_session.scalar.return_value = 0

        with patch("services.conversation_service.select") as mock_select:
            mock_stmt = MagicMock()
            mock_select.return_value = mock_stmt
            mock_stmt.where.return_value = mock_stmt
            mock_stmt.order_by.return_value = mock_stmt
            mock_stmt.limit.return_value = mock_stmt
            mock_stmt.subquery.return_value = MagicMock()

            result = ConversationService.pagination_by_last_id(
                session=mock_session,
                app_model=mock_app_model,
                user=mock_user,
                last_id=None,
                limit=20,
                invoke_from=InvokeFrom.WEB_APP,
                include_ids=None,
                exclude_ids=["conv1", "conv2"],  # Non-empty exclude_ids
            )

            # Verify the where clause was called for exclusion
            assert mock_stmt.where.called
|
||||
58
api/uv.lock
|
|
@@ -983,6 +983,25 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clickzetta-connector-python"
|
||||
version = "0.8.102"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "future" },
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pandas" },
|
||||
{ name = "pyarrow" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "requests" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cloudscraper"
|
||||
version = "1.2.71"
|
||||
|
|
@@ -1383,6 +1402,7 @@ vdb = [
|
|||
{ name = "alibabacloud-tea-openapi" },
|
||||
{ name = "chromadb" },
|
||||
{ name = "clickhouse-connect" },
|
||||
{ name = "clickzetta-connector-python" },
|
||||
{ name = "couchbase" },
|
||||
{ name = "elasticsearch" },
|
||||
{ name = "mo-vector" },
|
||||
|
|
@@ -1568,6 +1588,7 @@ vdb = [
|
|||
{ name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" },
|
||||
{ name = "chromadb", specifier = "==0.5.20" },
|
||||
{ name = "clickhouse-connect", specifier = "~=0.7.16" },
|
||||
{ name = "clickzetta-connector-python", specifier = ">=0.8.102" },
|
||||
{ name = "couchbase", specifier = "~=4.3.0" },
|
||||
{ name = "elasticsearch", specifier = "==8.14.0" },
|
||||
{ name = "mo-vector", specifier = "~=0.1.13" },
|
||||
|
|
@@ -2111,7 +2132,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "google-cloud-bigquery"
|
||||
version = "3.34.0"
|
||||
version = "3.30.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "google-api-core", extra = ["grpc"] },
|
||||
|
|
@@ -2122,9 +2143,9 @@ dependencies = [
|
|||
{ name = "python-dateutil" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/24/f9/e9da2d56d7028f05c0e2f5edf6ce43c773220c3172666c3dd925791d763d/google_cloud_bigquery-3.34.0.tar.gz", hash = "sha256:5ee1a78ba5c2ccb9f9a8b2bf3ed76b378ea68f49b6cac0544dc55cc97ff7c1ce", size = 489091, upload-time = "2025-05-29T17:18:06.03Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/2f/3dda76b3ec029578838b1fe6396e6b86eb574200352240e23dea49265bb7/google_cloud_bigquery-3.30.0.tar.gz", hash = "sha256:7e27fbafc8ed33cc200fe05af12ecd74d279fe3da6692585a3cef7aee90575b6", size = 474389, upload-time = "2025-02-27T18:49:45.416Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/7e/7115c4f67ca0bc678f25bff1eab56cc37d06eb9a3978940b2ebd0705aa0a/google_cloud_bigquery-3.34.0-py3-none-any.whl", hash = "sha256:de20ded0680f8136d92ff5256270b5920dfe4fae479f5d0f73e90e5df30b1cf7", size = 253555, upload-time = "2025-05-29T17:18:02.904Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/6d/856a6ca55c1d9d99129786c929a27dd9d31992628ebbff7f5d333352981f/google_cloud_bigquery-3.30.0-py2.py3-none-any.whl", hash = "sha256:f4d28d846a727f20569c9b2d2f4fa703242daadcb2ec4240905aa485ba461877", size = 247885, upload-time = "2025-02-27T18:49:43.454Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@@ -3918,11 +3939,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
version = "23.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@@ -4302,6 +4323,31 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "14.0.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
|
|
|
|||
|
|
@@ -333,6 +333,25 @@ OPENDAL_SCHEME=fs
# Configurations for OpenDAL Local File System.
OPENDAL_FS_ROOT=storage

# ClickZetta Volume Configuration (for storage backend)
# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume
# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters

# Volume type selection (three types available):
# - user: Personal/small team use, simple config, user-level permissions
# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions
# - external: Data lake integration, external storage connection, volume-level + storage-level permissions
CLICKZETTA_VOLUME_TYPE=user

# External Volume name (required only when TYPE=external)
CLICKZETTA_VOLUME_NAME=

# Table Volume table prefix (used only when TYPE=table)
CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_

# Dify file directory prefix (isolates from other apps, recommended to keep default)
CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km

# S3 Configuration
#
S3_ENDPOINT=
@@ -416,7 +435,7 @@ SUPABASE_URL=your-server-url
# ------------------------------

# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -655,6 +674,20 @@ TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false

# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta`
CLICKZETTA_USERNAME=
CLICKZETTA_PASSWORD=
CLICKZETTA_INSTANCE=
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=quick_start
CLICKZETTA_VCLUSTER=default_ap
CLICKZETTA_SCHEMA=dify
CLICKZETTA_BATCH_SIZE=100
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese
CLICKZETTA_ANALYZER_MODE=smart
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance

# ------------------------------
# Knowledge Configuration
# ------------------------------
@@ -93,6 +93,10 @@ x-shared-env: &shared-api-worker-env
  STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
  OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
  OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
  CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user}
  CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-}
  CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_}
  CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km}
  S3_ENDPOINT: ${S3_ENDPOINT:-}
  S3_REGION: ${S3_REGION:-us-east-1}
  S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}
@@ -313,6 +317,18 @@ x-shared-env: &shared-api-worker-env
  TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
  TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
  TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
  CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-}
  CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-}
  CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-}
  CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com}
  CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start}
  CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap}
  CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify}
  CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100}
  CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true}
  CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese}
  CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart}
  CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance}
  UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
  UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
  ETL_TYPE: ${ETL_TYPE:-dify}
@@ -0,0 +1,97 @@
/**
 * Description Validation Test
 *
 * Tests for the 400-character description validation across App and Dataset
 * creation and editing workflows to ensure consistent validation behavior.
 */

describe('Description Validation Logic', () => {
  // Simulate backend validation function
  const validateDescriptionLength = (description?: string | null) => {
    if (description && description.length > 400)
      throw new Error('Description cannot exceed 400 characters.')

    return description
  }

  describe('Backend Validation Function', () => {
    test('allows description within 400 characters', () => {
      const validDescription = 'x'.repeat(400)
      expect(() => validateDescriptionLength(validDescription)).not.toThrow()
      expect(validateDescriptionLength(validDescription)).toBe(validDescription)
    })

    test('allows empty description', () => {
      expect(() => validateDescriptionLength('')).not.toThrow()
      expect(() => validateDescriptionLength(null)).not.toThrow()
      expect(() => validateDescriptionLength(undefined)).not.toThrow()
    })

    test('rejects description exceeding 400 characters', () => {
      const invalidDescription = 'x'.repeat(401)
      expect(() => validateDescriptionLength(invalidDescription)).toThrow(
        'Description cannot exceed 400 characters.',
      )
    })
  })

  describe('Backend Validation Consistency', () => {
    test('App and Dataset have consistent validation limits', () => {
      const maxLength = 400
      const validDescription = 'x'.repeat(maxLength)
      const invalidDescription = 'x'.repeat(maxLength + 1)

      // Both should accept exactly 400 characters
      expect(validDescription.length).toBe(400)
      expect(() => validateDescriptionLength(validDescription)).not.toThrow()

      // Both should reject 401 characters
      expect(invalidDescription.length).toBe(401)
      expect(() => validateDescriptionLength(invalidDescription)).toThrow()
    })

    test('validation error messages are consistent', () => {
      const expectedErrorMessage = 'Description cannot exceed 400 characters.'

      // This would be the error message from both App and Dataset backend validation
      expect(expectedErrorMessage).toBe('Description cannot exceed 400 characters.')

      const invalidDescription = 'x'.repeat(401)
      try {
        validateDescriptionLength(invalidDescription)
      }
      catch (error) {
        expect((error as Error).message).toBe(expectedErrorMessage)
      }
    })
  })

  describe('Character Length Edge Cases', () => {
    const testCases = [
      { length: 0, shouldPass: true, description: 'empty description' },
      { length: 1, shouldPass: true, description: '1 character' },
      { length: 399, shouldPass: true, description: '399 characters' },
      { length: 400, shouldPass: true, description: '400 characters (boundary)' },
      { length: 401, shouldPass: false, description: '401 characters (over limit)' },
      { length: 500, shouldPass: false, description: '500 characters' },
      { length: 1000, shouldPass: false, description: '1000 characters' },
    ]

    testCases.forEach(({ length, shouldPass, description }) => {
      test(`handles ${description} correctly`, () => {
        const testDescription = length > 0 ? 'x'.repeat(length) : ''
        expect(testDescription.length).toBe(length)

        if (shouldPass) {
          expect(() => validateDescriptionLength(testDescription)).not.toThrow()
          expect(validateDescriptionLength(testDescription)).toBe(testDescription)
        }
        else {
          expect(() => validateDescriptionLength(testDescription)).toThrow(
            'Description cannot exceed 400 characters.',
          )
        }
      })
    })
  })
})
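The suite above drives a local stand-in for the 400-character rule. As a rough companion sketch of how the same limit could be enforced on the client before a create request is sent, one might write the following; DESCRIPTION_MAX_LENGTH and buildCreatePayload are illustrative names, not existing utilities in the codebase.

// A minimal sketch, assuming a shared constant; these names are hypothetical.
export const DESCRIPTION_MAX_LENGTH = 400

export const validateDescriptionLength = (description?: string | null) => {
  if (description && description.length > DESCRIPTION_MAX_LENGTH)
    throw new Error(`Description cannot exceed ${DESCRIPTION_MAX_LENGTH} characters.`)

  return description
}

type CreatePayload = {
  name: string
  description?: string
}

// Hypothetical form-side guard that fails fast with the same message the
// backend validation in the test above expects.
export const buildCreatePayload = (name: string, description?: string): CreatePayload => {
  validateDescriptionLength(description)
  return { name, description }
}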
@@ -4,7 +4,6 @@ import React, { useEffect, useMemo } from 'react'
import { usePathname } from 'next/navigation'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import {
  RiEqualizer2Fill,
  RiEqualizer2Line,

@@ -44,17 +43,12 @@ type IExtraInfoProps = {
}

const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
  const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile)
  const { t } = useTranslation()
  const docLink = useDocLink()

  const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0
  const relatedAppsTotal = relatedApps?.data?.length || 0

  useEffect(() => {
    setShowTips(!isMobile)
  }, [isMobile, setShowTips])

  return <div>
    {/* Related apps for desktop */}
    <div className={classNames(
|
|
|||
|
|
@ -1,131 +0,0 @@
|
|||
'use client'
|
||||
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiListUnordered } from '@remixicon/react'
|
||||
import TemplateEn from './template/template.en.mdx'
|
||||
import TemplateZh from './template/template.zh.mdx'
|
||||
import TemplateJa from './template/template.ja.mdx'
|
||||
import I18n from '@/context/i18n'
|
||||
import { LanguagesSupported } from '@/i18n-config/language'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { Theme } from '@/types/app'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type DocProps = {
|
||||
apiBaseUrl: string
|
||||
}
|
||||
|
||||
const Doc = ({ apiBaseUrl }: DocProps) => {
|
||||
const { locale } = useContext(I18n)
|
||||
const { t } = useTranslation()
|
||||
const [toc, setToc] = useState<Array<{ href: string; text: string }>>([])
|
||||
const [isTocExpanded, setIsTocExpanded] = useState(false)
|
||||
const { theme } = useTheme()
|
||||
|
||||
// Set initial TOC expanded state based on screen width
|
||||
useEffect(() => {
|
||||
const mediaQuery = window.matchMedia('(min-width: 1280px)')
|
||||
setIsTocExpanded(mediaQuery.matches)
|
||||
}, [])
|
||||
|
||||
// Extract TOC from article content
|
||||
useEffect(() => {
|
||||
const extractTOC = () => {
|
||||
const article = document.querySelector('article')
|
||||
if (article) {
|
||||
const headings = article.querySelectorAll('h2')
|
||||
const tocItems = Array.from(headings).map((heading) => {
|
||||
const anchor = heading.querySelector('a')
|
||||
if (anchor) {
|
||||
return {
|
||||
href: anchor.getAttribute('href') || '',
|
||||
text: anchor.textContent || '',
|
||||
}
|
||||
}
|
||||
return null
|
||||
}).filter((item): item is { href: string; text: string } => item !== null)
|
||||
setToc(tocItems)
|
||||
}
|
||||
}
|
||||
|
||||
setTimeout(extractTOC, 0)
|
||||
}, [locale])
|
||||
|
||||
// Handle TOC item click
|
||||
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string; text: string }) => {
|
||||
e.preventDefault()
|
||||
const targetId = item.href.replace('#', '')
|
||||
const element = document.getElementById(targetId)
|
||||
if (element) {
|
||||
const scrollContainer = document.querySelector('.scroll-container')
|
||||
if (scrollContainer) {
|
||||
const headerOffset = -40
|
||||
const elementTop = element.offsetTop - headerOffset
|
||||
scrollContainer.scrollTo({
|
||||
top: elementTop,
|
||||
behavior: 'smooth',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Template = useMemo(() => {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateZh apiBaseUrl={apiBaseUrl} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateJa apiBaseUrl={apiBaseUrl} />
|
||||
default:
|
||||
return <TemplateEn apiBaseUrl={apiBaseUrl} />
|
||||
}
|
||||
}, [apiBaseUrl, locale])
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<div className={`fixed right-20 top-32 z-10 transition-all ${isTocExpanded ? 'w-64' : 'w-10'}`}>
|
||||
{isTocExpanded
|
||||
? (
|
||||
<nav className="toc max-h-[calc(100vh-150px)] w-full overflow-y-auto rounded-lg bg-components-panel-bg p-4 shadow-md">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h3 className="text-lg font-semibold text-text-primary">{t('appApi.develop.toc')}</h3>
|
||||
<button
|
||||
onClick={() => setIsTocExpanded(false)}
|
||||
className="text-text-tertiary hover:text-text-secondary"
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
</div>
|
||||
<ul className="space-y-2">
|
||||
{toc.map((item, index) => (
|
||||
<li key={index}>
|
||||
<a
|
||||
href={item.href}
|
||||
className="text-text-secondary transition-colors duration-200 hover:text-text-primary hover:underline"
|
||||
onClick={e => handleTocClick(e, item)}
|
||||
>
|
||||
{item.text}
|
||||
</a>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</nav>
|
||||
)
|
||||
: (
|
||||
<button
|
||||
onClick={() => setIsTocExpanded(true)}
|
||||
className="flex h-10 w-10 items-center justify-center rounded-full bg-components-button-secondary-bg shadow-md transition-colors duration-200 hover:bg-components-button-secondary-bg-hover"
|
||||
>
|
||||
<RiListUnordered className="h-6 w-6 text-components-button-secondary-text" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<article className={cn('prose-xl prose mx-1 rounded-t-xl bg-background-default px-4 pt-16 sm:mx-12', theme === Theme.dark && 'prose-invert')}>
|
||||
{Template}
|
||||
</article>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Doc
|
||||
|
|
@@ -9,10 +9,10 @@ import { useQuery } from '@tanstack/react-query'

// Components
import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel'
import Datasets from './Datasets'
import DatasetFooter from './DatasetFooter'
import Datasets from './datasets'
import DatasetFooter from './dataset-footer'
import ApiServer from '../../components/develop/ApiServer'
import Doc from './Doc'
import Doc from './doc'
import TabSliderNew from '@/app/components/base/tab-slider-new'
import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter'
|
|||
}, [currentWorkspace, router])
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className='scroll-container relative flex grow flex-col overflow-y-auto bg-background-body'>
|
||||
<div className='sticky top-0 z-10 flex h-[80px] shrink-0 flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
|
||||
<div ref={containerRef} className={`scroll-container relative flex grow flex-col overflow-y-auto rounded-t-xl outline-none ${activeTab === 'dataset' ? 'bg-background-body' : 'bg-components-panel-bg'}`}>
|
||||
<div className={`sticky top-0 z-10 flex shrink-0 flex-wrap items-center justify-between gap-y-2 rounded-t-xl px-6 py-2 ${activeTab === 'api' ? 'border-b border-solid border-b-divider-regular' : ''} ${activeTab === 'dataset' ? 'bg-background-body' : 'bg-components-panel-bg'}`}>
|
||||
<TabSliderNew
|
||||
value={activeTab}
|
||||
onChange={newActiveTab => setActiveTab(newActiveTab)}
|
||||
|
|
@@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { RiMoreFill } from '@remixicon/react'
import { mutate } from 'swr'
import cn from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'
import { ToastContext } from '@/app/components/base/toast'
@@ -57,6 +58,19 @@ const DatasetCard = ({
  const onConfirmDelete = useCallback(async () => {
    try {
      await deleteDataset(dataset.id)

      // Clear SWR cache to prevent stale data in knowledge retrieval nodes
      mutate(
        (key) => {
          if (typeof key === 'string') return key.includes('/datasets')
          if (typeof key === 'object' && key !== null)
            return key.url === '/datasets' || key.url?.includes('/datasets')
          return false
        },
        undefined,
        { revalidate: true },
      )

      notify({ type: 'success', message: t('dataset.datasetDeleted') })
      if (onSuccess)
        onSuccess()
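The matcher passed to mutate above revalidates every cached SWR key that references /datasets, whether the key is a plain string or an object carrying a url field. Extracted into a standalone helper, a minimal sketch could look like this; invalidateDatasetCaches is a hypothetical name, not an existing utility in the codebase.

import { mutate } from 'swr'

// Hypothetical helper: revalidate every cached SWR key that points at /datasets.
export const invalidateDatasetCaches = () =>
  mutate(
    (key: unknown) => {
      if (typeof key === 'string') return key.includes('/datasets')
      if (typeof key === 'object' && key !== null) {
        const url = (key as { url?: string }).url
        return url === '/datasets' || Boolean(url?.includes('/datasets'))
      }
      return false
    },
    undefined,
    { revalidate: true },
  )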
@@ -162,24 +176,19 @@ const DatasetCard = ({
      </div>
      <div
        className={cn(
          'mb-2 max-h-[72px] grow px-[14px] text-xs leading-normal text-text-tertiary group-hover:line-clamp-2 group-hover:max-h-[36px]',
          tags.length ? 'line-clamp-2' : 'line-clamp-4',
          'mb-2 line-clamp-2 max-h-[36px] grow px-[14px] text-xs leading-normal text-text-tertiary',
          !dataset.embedding_available && 'opacity-50 hover:opacity-100',
        )}
        title={dataset.description}>
        {dataset.description}
      </div>
      <div className={cn(
        'mt-4 h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1',
        tags.length ? 'flex' : '!hidden group-hover:!flex',
      )}>
      <div className='mt-4 flex h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1'>
        <div className={cn('flex w-0 grow items-center gap-1', !dataset.embedding_available && 'opacity-50 hover:opacity-100')} onClick={(e) => {
          e.stopPropagation()
          e.preventDefault()
        }}>
          <div className={cn(
            'mr-[41px] w-full grow group-hover:!mr-0 group-hover:!block',
            tags.length ? '!block' : '!hidden',
            'mr-[41px] w-full grow group-hover:!mr-0',
          )}>
            <TagSelector
              position='bl'
@@ -3,8 +3,8 @@
import { useCallback, useEffect, useRef } from 'react'
import useSWRInfinite from 'swr/infinite'
import { debounce } from 'lodash-es'
import NewDatasetCard from './NewDatasetCard'
import DatasetCard from './DatasetCard'
import NewDatasetCard from './new-dataset-card'
import DatasetCard from './dataset-card'
import type { DataSetListResponse, FetchDatasetsParams } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import { useAppContext } from '@/context/app-context'
@@ -36,7 +36,7 @@ const getKey = (
}

type Props = {
  containerRef: React.RefObject<HTMLDivElement>
  containerRef: React.RefObject<HTMLDivElement | null>
  tags: string[]
  keywords: string
  includeAll: boolean
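A note on the containerRef change above: under current @types/react, useRef with an initial value of null is typed as RefObject of the element type or null, so widening the prop keeps the call site type-safe. A minimal sketch, with hypothetical component names:

import { useRef, type RefObject } from 'react'

type ListProps = {
  // Nullable ref type, matching what useRef(null) produces.
  containerRef: RefObject<HTMLDivElement | null>
}

const List = (_props: ListProps) => null

const Page = () => {
  // Typed as RefObject<HTMLDivElement | null> under current @types/react.
  const containerRef = useRef<HTMLDivElement>(null)
  return <List containerRef={containerRef} />
}

export default Page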
@ -0,0 +1,203 @@
|
|||
'use client'
|
||||
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiCloseLine, RiListUnordered } from '@remixicon/react'
|
||||
import TemplateEn from './template/template.en.mdx'
|
||||
import TemplateZh from './template/template.zh.mdx'
|
||||
import TemplateJa from './template/template.ja.mdx'
|
||||
import I18n from '@/context/i18n'
|
||||
import { LanguagesSupported } from '@/i18n-config/language'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { Theme } from '@/types/app'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type DocProps = {
|
||||
apiBaseUrl: string
|
||||
}
|
||||
|
||||
const Doc = ({ apiBaseUrl }: DocProps) => {
|
||||
const { locale } = useContext(I18n)
|
||||
const { t } = useTranslation()
|
||||
const [toc, setToc] = useState<Array<{ href: string; text: string }>>([])
|
||||
const [isTocExpanded, setIsTocExpanded] = useState(false)
|
||||
const [activeSection, setActiveSection] = useState<string>('')
|
||||
const { theme } = useTheme()
|
||||
|
||||
// Set initial TOC expanded state based on screen width
|
||||
useEffect(() => {
|
||||
const mediaQuery = window.matchMedia('(min-width: 1280px)')
|
||||
setIsTocExpanded(mediaQuery.matches)
|
||||
}, [])
|
||||
|
||||
// Extract TOC from article content
|
||||
useEffect(() => {
|
||||
const extractTOC = () => {
|
||||
const article = document.querySelector('article')
|
||||
if (article) {
|
||||
const headings = article.querySelectorAll('h2')
|
||||
const tocItems = Array.from(headings).map((heading) => {
|
||||
const anchor = heading.querySelector('a')
|
||||
if (anchor) {
|
||||
return {
|
||||
href: anchor.getAttribute('href') || '',
|
||||
text: anchor.textContent || '',
|
||||
}
|
||||
}
|
||||
return null
|
||||
}).filter((item): item is { href: string; text: string } => item !== null)
|
||||
setToc(tocItems)
|
||||
// Set initial active section
|
||||
if (tocItems.length > 0)
|
||||
setActiveSection(tocItems[0].href.replace('#', ''))
|
||||
}
|
||||
}
|
||||
|
||||
setTimeout(extractTOC, 0)
|
||||
}, [locale])
|
||||
|
||||
// Track scroll position for active section highlighting
|
||||
useEffect(() => {
|
||||
const handleScroll = () => {
|
||||
const scrollContainer = document.querySelector('.scroll-container')
|
||||
if (!scrollContainer || toc.length === 0)
|
||||
return
|
||||
|
||||
// Find active section based on scroll position
|
||||
let currentSection = ''
|
||||
toc.forEach((item) => {
|
||||
const targetId = item.href.replace('#', '')
|
||||
const element = document.getElementById(targetId)
|
||||
if (element) {
|
||||
const rect = element.getBoundingClientRect()
|
||||
// Consider section active if its top is above the middle of viewport
|
||||
if (rect.top <= window.innerHeight / 2)
|
||||
currentSection = targetId
|
||||
}
|
||||
})
|
||||
|
||||
if (currentSection && currentSection !== activeSection)
|
||||
setActiveSection(currentSection)
|
||||
}
|
||||
|
||||
const scrollContainer = document.querySelector('.scroll-container')
|
||||
if (scrollContainer) {
|
||||
scrollContainer.addEventListener('scroll', handleScroll)
|
||||
handleScroll() // Initial check
|
||||
return () => scrollContainer.removeEventListener('scroll', handleScroll)
|
||||
}
|
||||
}, [toc, activeSection])
|
||||
|
||||
// Handle TOC item click
|
||||
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string; text: string }) => {
|
||||
e.preventDefault()
|
||||
const targetId = item.href.replace('#', '')
|
||||
const element = document.getElementById(targetId)
|
||||
if (element) {
|
||||
const scrollContainer = document.querySelector('.scroll-container')
|
||||
if (scrollContainer) {
|
||||
const headerOffset = -40
|
||||
const elementTop = element.offsetTop - headerOffset
|
||||
scrollContainer.scrollTo({
|
||||
top: elementTop,
|
||||
behavior: 'smooth',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Template = useMemo(() => {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateZh apiBaseUrl={apiBaseUrl} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateJa apiBaseUrl={apiBaseUrl} />
|
||||
default:
|
||||
return <TemplateEn apiBaseUrl={apiBaseUrl} />
|
||||
}
|
||||
}, [apiBaseUrl, locale])
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<div className={`fixed right-20 top-32 z-10 transition-all duration-150 ease-out ${isTocExpanded ? 'w-[280px]' : 'w-11'}`}>
|
||||
{isTocExpanded
|
||||
? (
|
||||
<nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
|
||||
<div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
|
||||
<span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
|
||||
{t('appApi.develop.toc')}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => setIsTocExpanded(false)}
|
||||
className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
|
||||
aria-label="Close"
|
||||
>
|
||||
<RiCloseLine className="h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
|
||||
<div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>
|
||||
|
||||
<div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
|
||||
{toc.length === 0 ? (
|
||||
<div className="px-2 py-8 text-center text-xs text-text-quaternary">
|
||||
{t('appApi.develop.noContent')}
|
||||
</div>
|
||||
) : (
|
||||
<ul className="space-y-0.5">
|
||||
{toc.map((item, index) => {
|
||||
const isActive = activeSection === item.href.replace('#', '')
|
||||
return (
|
||||
<li key={index}>
|
||||
<a
|
||||
href={item.href}
|
||||
onClick={e => handleTocClick(e, item)}
|
||||
className={cn(
|
||||
'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
|
||||
isActive
|
||||
? 'bg-state-base-hover font-medium text-text-primary'
|
||||
: 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
|
||||
isActive
|
||||
? 'scale-100 bg-text-accent'
|
||||
: 'scale-75 bg-components-panel-border',
|
||||
)}
|
||||
/>
|
||||
<span className="flex-1 truncate">
|
||||
{item.text}
|
||||
</span>
|
||||
</a>
|
||||
</li>
|
||||
)
|
||||
})}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
|
||||
</nav>
|
||||
)
|
||||
: (
|
||||
<button
|
||||
onClick={() => setIsTocExpanded(true)}
|
||||
className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
|
||||
aria-label="Open table of contents"
|
||||
>
|
||||
<RiListUnordered className="h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<article className={cn('prose-xl prose', theme === Theme.dark && 'prose-invert')}>
|
||||
{Template}
|
||||
</article>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Doc
|
||||
|
|
@@ -1,6 +1,6 @@
'use client'
import { useTranslation } from 'react-i18next'
import Container from './Container'
import Container from './container'
import useDocumentTitle from '@/hooks/use-document-title'

const AppList = () => {
@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</CodeGroup>
|
||||
</div>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/document/create-by-text'
|
||||
|
|
@ -163,7 +163,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/document/create-by-file'
|
||||
|
|
@ -294,7 +294,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
|
|
@ -400,7 +400,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
|
|
@ -472,7 +472,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -553,7 +553,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -714,7 +714,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -751,7 +751,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
|
||||
|
|
@ -853,7 +853,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
|
||||
|
|
@ -952,7 +952,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{batch}/indexing-status'
|
||||
|
|
@ -1007,7 +1007,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
|
|
@ -1047,7 +1047,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents'
|
||||
|
|
@ -1122,7 +1122,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
|
|
@ -1245,7 +1245,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
___
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/status/{action}'
|
||||
|
|
@ -1302,7 +1302,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments'
|
||||
|
|
@ -1388,7 +1388,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments'
|
||||
|
|
@ -1476,7 +1476,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
|
|
@ -1546,7 +1546,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
|
|
@ -1590,7 +1590,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
|
|
@ -1679,7 +1679,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
|
||||
|
|
@ -1750,7 +1750,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
|
||||
|
|
@ -1827,7 +1827,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
|
||||
|
|
@ -1873,7 +1873,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
|
||||
|
|
@ -1947,7 +1947,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/upload-file'
|
||||
|
|
@ -1998,7 +1998,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/retrieve'
|
||||
|
|
@ -2177,7 +2177,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata'
|
||||
|
|
@ -2224,7 +2224,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata/{metadata_id}'
|
||||
|
|
@ -2273,7 +2273,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata/{metadata_id}'
|
||||
|
|
@ -2306,7 +2306,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata/built-in/{action}'
|
||||
|
|
@ -2339,7 +2339,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/metadata'
|
||||
|
|
@ -2378,7 +2378,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata'
|
||||
|
|
@ -2424,7 +2424,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/workspaces/current/models/model-types/text-embedding'
|
||||
|
|
@ -2528,7 +2528,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
Okay, I will translate the Chinese text in your document while keeping all formatting and code content unchanged.
|
||||
|
||||
<Heading
|
||||
|
|
@ -2574,7 +2574,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Row>
|
||||
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags'
|
||||
|
|
@ -2615,7 +2615,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags'
|
||||
|
|
@ -2662,7 +2662,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
|
||||
<Heading
|
||||
|
|
@ -2704,7 +2704,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags/binding'
|
||||
|
|
@ -2746,7 +2746,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags/unbinding'
|
||||
|
|
@ -2789,7 +2789,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Row>
|
||||
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/<uuid:dataset_id>/tags'
|
||||
|
|
@ -2837,7 +2837,7 @@ Okay, I will translate the Chinese text in your document while keeping all forma
|
|||
</Row>
|
||||
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
|
||||
<Row>
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</CodeGroup>
|
||||
</div>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/document/create-by-text'
|
||||
|
|
@ -163,7 +163,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/document/create-by-file'
|
||||
|
|
@ -294,7 +294,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
|
|
@ -399,7 +399,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
|
|
@ -471,7 +471,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -508,7 +508,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
|
||||
|
|
@ -610,7 +610,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
|
||||
|
|
@ -709,7 +709,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{batch}/indexing-status'
|
||||
|
|
@ -764,7 +764,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
|
|
@ -804,7 +804,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents'
|
||||
|
|
@ -879,7 +879,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
|
|
@ -1002,7 +1002,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
___
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
|
||||
<Heading
|
||||
|
|
@ -1060,7 +1060,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments'
|
||||
|
|
@ -1146,7 +1146,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments'
|
||||
|
|
@ -1234,7 +1234,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
|
|
@ -1304,7 +1304,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
method='DELETE'
|
||||
|
|
@ -1347,7 +1347,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
method='POST'
|
||||
|
|
@ -1435,7 +1435,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
|
||||
|
|
@ -1506,7 +1506,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
|
||||
|
|
@ -1583,7 +1583,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
|
||||
|
|
@ -1629,7 +1629,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
|
||||
|
|
@ -1703,7 +1703,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/upload-file'
|
||||
|
|
@ -1754,7 +1754,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/retrieve'
|
||||
|
|
@ -1933,7 +1933,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata'
|
||||
|
|
@ -1980,7 +1980,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata/{metadata_id}'
|
||||
|
|
@ -2029,7 +2029,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata/{metadata_id}'
|
||||
|
|
@ -2062,7 +2062,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata/built-in/{action}'
|
||||
|
|
@ -2095,7 +2095,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/metadata'
|
||||
|
|
@ -2136,7 +2136,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/metadata'
|
||||
|
|
@ -2182,7 +2182,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
<Heading
|
||||
url='/datasets/tags'
|
||||
method='POST'
|
||||
|
|
@ -2226,7 +2226,7 @@ ___
|
|||
</Row>
|
||||
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags'
|
||||
|
|
@ -2267,7 +2267,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags'
|
||||
|
|
@ -2314,7 +2314,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
|
||||
<Heading
|
||||
|
|
@ -2356,7 +2356,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags/binding'
|
||||
|
|
@ -2398,7 +2398,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/tags/unbinding'
|
||||
|
|
@ -2441,7 +2441,7 @@ ___
|
|||
</Row>
|
||||
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/<uuid:dataset_id>/tags'
|
||||
|
|
@ -2489,7 +2489,7 @@ ___
|
|||
</Row>
|
||||
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</CodeGroup>
|
||||
</div>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/document/create-by-text'
|
||||
|
|
@ -167,7 +167,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/document/create-by-file'
|
||||
|
|
@ -298,7 +298,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
|
|
@ -403,7 +403,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
|
|
@ -475,7 +475,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -556,7 +556,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -721,7 +721,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}'
|
||||
|
|
@ -758,7 +758,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
|
||||
|
|
@ -860,7 +860,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
|
||||
|
|
@ -959,7 +959,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{batch}/indexing-status'
|
||||
|
|
@ -1014,7 +1014,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
|
|
@ -1054,7 +1054,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents'
|
||||
|
|
@ -1129,7 +1129,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
|
|
@ -1252,7 +1252,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
</Col>
|
||||
</Row>
|
||||
___
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
|
||||
<Heading
|
||||
|
|
@ -1310,7 +1310,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments'
|
||||
|
|
@ -1396,7 +1396,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments'
|
||||
|
|
@ -1484,7 +1484,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
|
|
@ -1528,7 +1528,7 @@ ___
|
|||
</Col>
|
||||
</Row>
|
||||
|
||||
<hr className='ml-0 mr-0' />
|
||||
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
|
|
@@ -1598,7 +1598,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
method='POST'
@@ -1687,7 +1687,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@@ -1758,7 +1758,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks'
@@ -1835,7 +1835,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@@ -1881,7 +1881,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Row>
<Col>
@@ -1915,7 +1915,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}'
@@ -1989,7 +1989,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/upload-file'
@@ -2040,7 +2040,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/retrieve'
@@ -2219,7 +2219,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@@ -2266,7 +2266,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@@ -2315,7 +2315,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/{metadata_id}'
@@ -2348,7 +2348,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata/built-in/{action}'
@@ -2381,7 +2381,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/documents/metadata'
@@ -2422,7 +2422,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/{dataset_id}/metadata'
@@ -2468,7 +2468,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/workspaces/current/models/model-types/text-embedding'
@@ -2572,7 +2572,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@@ -2617,7 +2617,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@@ -2658,7 +2658,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags'
@@ -2705,7 +2705,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
@@ -2747,7 +2747,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/binding'
@@ -2789,7 +2789,7 @@ ___
</Col>
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/tags/unbinding'
@@ -2832,7 +2832,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Heading
url='/datasets/<uuid:dataset_id>/tags'
@@ -2880,7 +2880,7 @@ ___
</Row>
<hr className='ml-0 mr-0' />
<hr style={{ marginLeft: 0, marginRight: 0, width: '100%', maxWidth: '100%' }} />
<Row>
<Col>
@@ -87,7 +87,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
<Avatar {...props} />
<div
onClick={() => { setIsShowAvatarPicker(true) }}
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black bg-opacity-50 opacity-0 transition-opacity group-hover:opacity-100"
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100"
>
<span className="text-xs text-white">
<RiPencilLine />
@@ -12,7 +12,6 @@ import {
RiFileUploadLine,
} from '@remixicon/react'
import AppIcon from '../base/app-icon'
import cn from '@/utils/classnames'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import { useAppContext } from '@/context/app-context'
@@ -31,6 +30,7 @@ import Divider from '../base/divider'
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
import cn from '@/utils/classnames'
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {
ssr: false,
@@ -256,32 +256,40 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
}}
className='block w-full'
>
<div className={cn('flex rounded-lg', expand ? 'flex-col gap-2 p-2 pb-2.5' : 'items-start justify-center gap-1 p-1', open && 'bg-state-base-hover', isCurrentWorkspaceEditor && 'cursor-pointer hover:bg-state-base-hover')}>
<div className={`flex items-center self-stretch ${expand ? 'justify-between' : 'flex-col gap-1'}`}>
<AppIcon
size={expand ? 'large' : 'small'}
iconType={appDetail.icon_type}
icon={appDetail.icon}
background={appDetail.icon_background}
imageUrl={appDetail.icon_url}
/>
<div className='flex items-center justify-center rounded-md p-0.5'>
<div className='flex h-5 w-5 items-center justify-center'>
<div className='flex flex-col gap-2 rounded-lg p-1 hover:bg-state-base-hover'>
<div className='flex items-center gap-1'>
<div className={cn(!expand && 'ml-1')}>
<AppIcon
size={expand ? 'large' : 'small'}
iconType={appDetail.icon_type}
icon={appDetail.icon}
background={appDetail.icon_background}
imageUrl={appDetail.icon_url}
/>
</div>
{expand && (
<div className='ml-auto flex items-center justify-center rounded-md p-0.5'>
<div className='flex h-5 w-5 items-center justify-center'>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</div>
</div>
)}
</div>
{!expand && (
<div className='flex items-center justify-center'>
<div className='flex h-5 w-5 items-center justify-center rounded-md p-0.5'>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</div>
</div>
</div>
<div className={cn(
'flex flex-col items-start gap-1 transition-all duration-200 ease-in-out',
expand
? 'w-auto opacity-100'
: 'pointer-events-none w-0 overflow-hidden opacity-0',
)}>
<div className='flex w-full'>
<div className='system-md-semibold truncate whitespace-nowrap text-text-secondary'>{appDetail.name}</div>
)}
{expand && (
<div className='flex flex-col items-start gap-1'>
<div className='flex w-full'>
<div className='system-md-semibold truncate whitespace-nowrap text-text-secondary'>{appDetail.name}</div>
</div>
<div className='system-2xs-medium-uppercase whitespace-nowrap text-text-tertiary'>{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}</div>
</div>
<div className='system-2xs-medium-uppercase whitespace-nowrap text-text-tertiary'>{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}</div>
</div>
)}
</div>
</button>
)}
@@ -32,7 +32,7 @@ const AccessControlDialog = ({
leaveFrom="opacity-100"
leaveTo="opacity-0"
>
<div className="fixed inset-0 bg-background-overlay bg-opacity-25" />
<div className="bg-background-overlay/25 fixed inset-0" />
</Transition.Child>
<div className="fixed inset-0 flex items-center justify-center">
@@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() {
setSelectedGroupsForBreadcrumb([])
}, [setSelectedGroupsForBreadcrumb])
return <div className='flex h-7 items-center gap-x-0.5 px-2 py-0.5'>
<span className={classNames('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'text-text-accent cursor-pointer')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
<span className={classNames('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
{selectedGroupsForBreadcrumb.map((group, index) => {
return <div key={index} className='system-xs-regular flex items-center gap-x-0.5 text-text-tertiary'>
<span>/</span>
@@ -198,7 +198,7 @@ type BaseItemProps = {
children: React.ReactNode
}
function BaseItem({ children, className }: BaseItemProps) {
return <div className={classNames('p-1 pl-2 flex items-center space-x-2 hover:rounded-lg hover:bg-state-base-hover cursor-pointer', className)}>
return <div className={classNames('flex cursor-pointer items-center space-x-2 p-1 pl-2 hover:rounded-lg hover:bg-state-base-hover', className)}>
{children}
</div>
}
@@ -4,7 +4,6 @@ import React, { useRef, useState } from 'react'
import { useGetState, useInfiniteScroll } from 'ahooks'
import { useTranslation } from 'react-i18next'
import Link from 'next/link'
import produce from 'immer'
import TypeIcon from '../type-icon'
import Modal from '@/app/components/base/modal'
import type { DataSet } from '@/models/datasets'
@@ -29,9 +28,10 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
onSelect,
}) => {
const { t } = useTranslation()
const [selected, setSelected] = React.useState<DataSet[]>(selectedIds.map(id => ({ id }) as any))
const [selected, setSelected] = React.useState<DataSet[]>([])
const [loaded, setLoaded] = React.useState(false)
const [datasets, setDataSets] = React.useState<DataSet[] | null>(null)
const [hasInitialized, setHasInitialized] = React.useState(false)
const hasNoData = !datasets || datasets?.length === 0
const canSelectMulti = true
@@ -49,19 +49,17 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
const newList = [...(datasets || []), ...data.filter(item => item.indexing_technique || item.provider === 'external')]
setDataSets(newList)
setLoaded(true)
if (!selected.find(item => !item.name))
return { list: [] }
const newSelected = produce(selected, (draft) => {
selected.forEach((item, index) => {
if (!item.name) { // not fetched database
const newItem = newList.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setSelected(newSelected)
// Initialize selected datasets based on selectedIds and available datasets
if (!hasInitialized) {
if (selectedIds.length > 0) {
const validSelectedDatasets = selectedIds
.map(id => newList.find(item => item.id === id))
.filter(Boolean) as DataSet[]
setSelected(validSelectedDatasets)
}
setHasInitialized(true)
}
}
return { list: [] }
},
@@ -55,8 +55,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding)
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const { t } = useTranslation()
const docLink = useDocLink()
@@ -40,13 +40,13 @@ type CategoryItemProps = {
}
function CategoryItem({ category, active, onClick }: CategoryItemProps) {
return <li
className={classNames('p-1 pl-3 h-8 rounded-lg flex items-center gap-2 group cursor-pointer hover:bg-state-base-hover [&.active]:bg-state-base-active', active && 'active')}
className={classNames('group flex h-8 cursor-pointer items-center gap-2 rounded-lg p-1 pl-3 hover:bg-state-base-hover [&.active]:bg-state-base-active', active && 'active')}
onClick={() => { onClick?.(category) }}>
{category === AppCategories.RECOMMENDED && <div className='inline-flex h-5 w-5 items-center justify-center rounded-md'>
<RiThumbUpLine className='h-4 w-4 text-components-menu-item-text group-[.active]:text-components-menu-item-text-active' />
</div>}
<AppCategoryLabel category={category}
className={classNames('system-sm-medium text-components-menu-item-text group-[.active]:text-components-menu-item-text-active group-hover:text-components-menu-item-text-hover', active && 'system-sm-semibold')} />
className={classNames('system-sm-medium text-components-menu-item-text group-hover:text-components-menu-item-text-hover group-[.active]:text-components-menu-item-text-active', active && 'system-sm-semibold')} />
</li >
}
@@ -82,8 +82,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
getRedirection(isCurrentWorkspaceEditor, app, push)
}
catch {
notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
catch (e: any) {
notify({
type: 'error',
message: e.message || t('app.newApp.appCreateFailed'),
})
}
isCreatingRef.current = false
}, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor])
@@ -106,8 +106,8 @@ const Uploader: FC<Props> = ({
<div className='flex w-full items-center justify-center space-x-2'>
<RiUploadCloud2Line className='h-6 w-6 text-text-tertiary' />
<div className='text-text-tertiary'>
{t('datasetCreation.stepOne.uploader.button')}
<span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('datasetDocuments.list.batchModal.browse')}</span>
{t('app.dslUploader.button')}
<span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('app.dslUploader.browse')}</span>
</div>
</div>
{dragging && <div ref={dragRef} className='absolute left-0 top-0 h-full w-full' />}
@@ -117,8 +117,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
if (onRefresh)
onRefresh()
}
catch {
notify({ type: 'error', message: t('app.editFailed') })
catch (e: any) {
notify({
type: 'error',
message: e.message || t('app.editFailed'),
})
}
}, [app.id, notify, onRefresh, t])
@@ -364,26 +367,20 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
</div>
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>
<div
className={cn(tags.length ? 'line-clamp-2' : 'line-clamp-4', 'group-hover:line-clamp-2')}
className='line-clamp-2'
title={app.description}
>
{app.description}
</div>
</div>
<div className={cn(
'absolute bottom-1 left-0 right-0 h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1',
tags.length ? 'flex' : '!hidden group-hover:!flex',
)}>
<div className='absolute bottom-1 left-0 right-0 flex h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1'>
{isCurrentWorkspaceEditor && (
<>
<div className={cn('flex w-0 grow items-center gap-1')} onClick={(e) => {
e.stopPropagation()
e.preventDefault()
}}>
<div className={cn(
'mr-[41px] w-full grow group-hover:!mr-0 group-hover:!block',
tags.length ? '!block' : '!hidden',
)}>
<div className='mr-[41px] w-full grow group-hover:!mr-0'>
<TagSelector
position='bl'
type='app'
@@ -395,7 +392,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
/>
</div>
</div>
<div className='mx-1 !hidden h-[14px] w-[1px] shrink-0 group-hover:!flex' />
<div className='mx-1 !hidden h-[14px] w-[1px] shrink-0 bg-divider-regular group-hover:!flex' />
<div className='!hidden shrink-0 group-hover:!flex'>
<CustomPopover
htmlContent={<Operations />}
@@ -1,6 +1,6 @@
import React, { useState } from 'react'
import React from 'react'
import Link from 'next/link'
import { RiCloseLine, RiDiscordFill, RiGithubFill } from '@remixicon/react'
import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
type CustomLinkProps = {
@@ -26,24 +26,9 @@ const CustomLink = React.memo(({
const Footer = () => {
const { t } = useTranslation()
const [isVisible, setIsVisible] = useState(true)
const handleClose = () => {
setIsVisible(false)
}
if (!isVisible)
return null
return (
<footer className='relative shrink-0 grow-0 px-12 py-2'>
<button
onClick={handleClose}
className='absolute right-2 top-2 flex h-6 w-6 cursor-pointer items-center justify-center rounded-full transition-colors duration-200 ease-in-out hover:bg-components-main-nav-nav-button-bg-active'
aria-label="Close footer"
>
<RiCloseLine className='h-4 w-4 text-text-tertiary hover:text-text-secondary' />
</button>
<h3 className='text-gradient text-xl font-semibold leading-tight'>{t('app.join')}</h3>
<p className='system-sm-regular mt-1 text-text-tertiary'>{t('app.communityIntro')}</p>
<div className='mt-3 flex items-center gap-2'>
@@ -1,14 +1,11 @@
'use client'
import { useEducationInit } from '@/app/education-apply/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import List from './list'
import Footer from './footer'
import useDocumentTitle from '@/hooks/use-document-title'
import { useTranslation } from 'react-i18next'
const Apps = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
useDocumentTitle(t('common.menus.apps'))
useEducationInit()
@@ -16,9 +13,6 @@ const Apps = () => {
return (
<div className='relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body'>
<List />
{!systemFeatures.branding.enabled && (
<Footer />
)}
</div >
)
}
@@ -32,6 +32,8 @@ import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
import dynamic from 'next/dynamic'
import Empty from './empty'
import Footer from './footer'
import { useGlobalPublicStore } from '@/context/global-public-context'
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
ssr: false,
@@ -66,6 +68,7 @@ const getKey = (
const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext()
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
@@ -229,6 +232,9 @@ const List = () => {
<span className="system-xs-regular">{t('app.newApp.dropDSLToCreateApp')}</span>
</div>
)}
{!systemFeatures.branding.enabled && (
<Footer />
)}
<CheckModal />
<div ref={anchorRef} className='h-0'> </div>
{showTagManagementModal && (
Some files were not shown because too many files have changed in this diff.