diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000..3e95f2e982
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,1197 @@
+# ------------------------------
+# Environment Variables for API service & worker
+# ------------------------------
+
+# ------------------------------
+# Common Variables
+# ------------------------------
+
+# The backend URL of the console API,
+# used to concatenate the authorization callback.
+# If empty, it is the same domain.
+# Example: https://api.console.dify.ai
+CONSOLE_API_URL=
+
+# The front-end URL of the console web,
+# used to concatenate some front-end addresses and for CORS configuration use.
+# If empty, it is the same domain.
+# Example: https://console.dify.ai
+CONSOLE_WEB_URL=
+
+# Service API Url,
+# used to display Service API Base Url to the front-end.
+# If empty, it is the same domain.
+# Example: https://api.dify.ai
+SERVICE_API_URL=
+
+# WebApp API backend Url,
+# used to declare the back-end URL for the front-end API.
+# If empty, it is the same domain.
+# Example: https://api.app.dify.ai
+APP_API_URL=
+
+# WebApp Url,
+# used to display the WebApp API Base Url to the front-end.
+# If empty, it is the same domain.
+# Example: https://app.dify.ai
+APP_WEB_URL=
+
+# File preview or download Url prefix.
+# Used to display File preview or download Url to the front-end or as multi-modal inputs;
+# Url is signed and has expiration time.
+# Setting FILES_URL is required for file processing plugins.
+# - For https://example.com, use FILES_URL=https://example.com
+# - For http://example.com, use FILES_URL=http://example.com
+# Recommendation: use a dedicated domain (e.g., https://upload.example.com).
+# Alternatively, use http://<your-ip>:5001 or http://api:5001,
+# ensuring port 5001 is externally accessible (see docker-compose.yaml).
+FILES_URL=
+
+# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
+# Set this to the internal Docker service URL for proper plugin file access.
+# Example: INTERNAL_FILES_URL=http://api:5001
+INTERNAL_FILES_URL=
+
+# ------------------------------
+# Server Configuration
+# ------------------------------
+
+# The log level for the application.
+# Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
+LOG_LEVEL=INFO
+# Log file path
+LOG_FILE=/app/logs/server.log
+# Log file max size, the unit is MB
+LOG_FILE_MAX_SIZE=20
+# Log file max backup count
+LOG_FILE_BACKUP_COUNT=5
+# Log date format
+LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
+# Log Timezone
+LOG_TZ=UTC
+
+# Debug mode, default is false.
+# It is recommended to turn on this configuration for local development
+# to prevent some problems caused by monkey patching.
+DEBUG=false
+
+# Flask debug mode. When turned on, it can output trace information at the interface,
+# which is convenient for debugging.
+FLASK_DEBUG=false
+
+# Enable request logging, which will log the request and response information.
+# These logs are written at the DEBUG level.
+ENABLE_REQUEST_LOGGING=False
+
+# A secret key that is used for securely signing the session cookie
+# and encrypting sensitive information in the database.
+# You can generate a strong key using `openssl rand -base64 42`.
+SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U
+
+# Password for admin user initialization.
+# If left unset, the admin user will not be prompted for a password
+# when creating the initial admin account.
+# The length of the password cannot exceed 30 characters.
+INIT_PASSWORD=
+
+# Deployment environment.
+# Supported values are `PRODUCTION`, `TESTING`.
Default is `PRODUCTION`. +# Testing environment. There will be a distinct color label on the front-end page, +# indicating that this environment is a testing environment. +DEPLOY_ENV=PRODUCTION + +# Whether to enable the version check policy. +# If set to empty, https://updates.dify.ai will be called for version check. +CHECK_UPDATE_URL=https://updates.dify.ai + +# Used to change the OpenAI base address, default is https://api.openai.com/v1. +# When OpenAI cannot be accessed in China, replace it with a domestic mirror address, +# or when a local model provides OpenAI compatible API, it can be replaced. +OPENAI_API_BASE=https://api.openai.com/v1 + +# When enabled, migrations will be executed prior to application startup +# and the application will start after the migrations have completed. +MIGRATION_ENABLED=true + +# File Access Time specifies a time interval in seconds for the file to be accessed. +# The default value is 300 seconds. +FILES_ACCESS_TIMEOUT=300 + +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 + +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 + +# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. +APP_MAX_ACTIVE_REQUESTS=0 +APP_MAX_EXECUTION_TIME=1200 + +# ------------------------------ +# Container Startup Related Configuration +# Only effective when starting with docker image or docker-compose. +# ------------------------------ + +# API service binding address, default: 0.0.0.0, i.e., all addresses can be accessed. +DIFY_BIND_ADDRESS=0.0.0.0 + +# API service binding port number, default 5001. +DIFY_PORT=5001 + +# The number of API server workers, i.e., the number of workers. +# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent +# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers +SERVER_WORKER_AMOUNT=1 + +# Defaults to gevent. If using windows, it can be switched to sync or solo. +SERVER_WORKER_CLASS=gevent + +# Default number of worker connections, the default is 10. +SERVER_WORKER_CONNECTIONS=10 + +# Similar to SERVER_WORKER_CLASS. +# If using windows, it can be switched to sync or solo. +CELERY_WORKER_CLASS= + +# Request handling timeout. The default is 200, +# it is recommended to set it to 360 to support a longer sse connection time. +GUNICORN_TIMEOUT=360 + +# The number of Celery workers. The default is 1, and can be set as needed. +CELERY_WORKER_AMOUNT= + +# Flag indicating whether to enable autoscaling of Celery workers. +# +# Autoscaling is useful when tasks are CPU intensive and can be dynamically +# allocated and deallocated based on the workload. +# +# When autoscaling is enabled, the maximum and minimum number of workers can +# be specified. The autoscaling algorithm will dynamically adjust the number +# of workers within the specified range. +# +# Default is false (i.e., autoscaling is disabled). +# +# Example: +# CELERY_AUTO_SCALE=true +CELERY_AUTO_SCALE=false + +# The maximum number of Celery workers that can be autoscaled. +# This is optional and only used when autoscaling is enabled. +# Default is not set. +CELERY_MAX_WORKERS= + +# The minimum number of Celery workers that can be autoscaled. +# This is optional and only used when autoscaling is enabled. +# Default is not set. 
+CELERY_MIN_WORKERS=
+
+# API Tool configuration
+API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
+API_TOOL_DEFAULT_READ_TIMEOUT=60
+
+# ------------------------------
+# Datasource Configuration
+# ------------------------------
+ENABLE_WEBSITE_JINAREADER=true
+ENABLE_WEBSITE_FIRECRAWL=true
+ENABLE_WEBSITE_WATERCRAWL=true
+
+# ------------------------------
+# Database Configuration
+# The database uses PostgreSQL. Please use the public schema.
+# It is consistent with the configuration in the 'db' service below.
+# ------------------------------
+
+DB_USERNAME=postgres
+DB_PASSWORD=difyai123456
+DB_HOST=db
+DB_PORT=5432
+DB_DATABASE=dify
+# The size of the database connection pool.
+# The default is 30 connections, which can be appropriately increased.
+SQLALCHEMY_POOL_SIZE=30
+# Database connection pool recycling time, the default is 3600 seconds.
+SQLALCHEMY_POOL_RECYCLE=3600
+# Whether to print SQL, default is false.
+SQLALCHEMY_ECHO=false
+# If true, connections are tested for liveness upon each checkout.
+SQLALCHEMY_POOL_PRE_PING=false
+# Whether to enable the last-in-first-out (LIFO) option; if false, the default FIFO queue is used.
+SQLALCHEMY_POOL_USE_LIFO=false
+
+# Maximum number of connections to the database
+# Default is 100
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-MAX-CONNECTIONS
+POSTGRES_MAX_CONNECTIONS=100
+
+# Sets the amount of shared memory used for postgres's shared buffers.
+# Default is 128MB
+# Recommended value: 25% of available memory
+# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-SHARED-BUFFERS
+POSTGRES_SHARED_BUFFERS=128MB
+
+# Sets the amount of memory used by each database worker for working space.
+# Default is 4MB
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-WORK-MEM
+POSTGRES_WORK_MEM=4MB
+
+# Sets the amount of memory reserved for maintenance activities.
+# Default is 64MB
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-MAINTENANCE-WORK-MEM
+POSTGRES_MAINTENANCE_WORK_MEM=64MB
+
+# Sets the planner's assumption about the effective cache size.
+# Default is 4096MB
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE
+POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
+
+# ------------------------------
+# Redis Configuration
+# This Redis configuration is used for caching and for pub/sub during conversation.
+# ------------------------------
+
+REDIS_HOST=redis
+REDIS_PORT=6379
+REDIS_USERNAME=
+REDIS_PASSWORD=difyai123456
+REDIS_USE_SSL=false
+REDIS_DB=0
+
+# Whether to use Redis Sentinel mode.
+# If set to true, the application will automatically discover and connect to the master node through Sentinel.
+REDIS_USE_SENTINEL=false
+
+# List of Redis Sentinel nodes. If Sentinel mode is enabled, provide at least one Sentinel IP and port.
+# Format: `<sentinel1_ip>:<sentinel1_port>,<sentinel2_ip>:<sentinel2_port>,<sentinel3_ip>:<sentinel3_port>`
+REDIS_SENTINELS=
+REDIS_SENTINEL_SERVICE_NAME=
+REDIS_SENTINEL_USERNAME=
+REDIS_SENTINEL_PASSWORD=
+REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
+
+# List of Redis Cluster nodes. If Cluster mode is enabled, provide at least one Cluster IP and port.
+# Format: `<Cluster1_ip>:<Cluster1_port>,<Cluster2_ip>:<Cluster2_port>,<Cluster3_ip>:<Cluster3_port>`
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=
+
+# ------------------------------
+# Celery Configuration
+# ------------------------------
+
+# Use Redis as the broker, and Redis db 1 for the Celery broker.
+# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`
+# Example: redis://:difyai123456@redis:6379/1
+# If using Redis Sentinel, the format is as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
+# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
+CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
+BROKER_USE_SSL=false
+
+# If you are using Redis Sentinel for high availability, configure the following settings.
+CELERY_USE_SENTINEL=false
+CELERY_SENTINEL_MASTER_NAME=
+CELERY_SENTINEL_PASSWORD=
+CELERY_SENTINEL_SOCKET_TIMEOUT=0.1
+
+# ------------------------------
+# CORS Configuration
+# Used to set the front-end cross-domain access policy.
+# ------------------------------
+
+# Specifies the allowed origins for cross-origin requests to the Web API,
+# e.g. https://dify.app or * for all origins.
+WEB_API_CORS_ALLOW_ORIGINS=*
+
+# Specifies the allowed origins for cross-origin requests to the console API,
+# e.g. https://cloud.dify.ai or * for all origins.
+CONSOLE_CORS_ALLOW_ORIGINS=*
+
+# ------------------------------
+# File Storage Configuration
+# ------------------------------
+
+# The type of storage to use for storing user files.
+STORAGE_TYPE=opendal
+
+# Apache OpenDAL Configuration
+# The configuration for OpenDAL consists of the following format: OPENDAL_<scheme_name>_<config_name>.
+# You can find all the service configurations (CONFIG_NAME) in the repository at: https://github.com/apache/opendal/tree/main/core/src/services.
+# Dify will scan configurations starting with OPENDAL_ and automatically apply them.
+# The scheme name for the OpenDAL storage.
+OPENDAL_SCHEME=fs
+# Configurations for OpenDAL Local File System.
+OPENDAL_FS_ROOT=storage
+
+# ClickZetta Volume Configuration (for storage backend)
+# To use ClickZetta Volume as the storage backend, set STORAGE_TYPE=clickzetta-volume
+# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters
+
+# Volume type selection (three types available):
+# - user: Personal/small team use, simple config, user-level permissions
+# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions
+# - external: Data lake integration, external storage connection, volume-level + storage-level permissions
+CLICKZETTA_VOLUME_TYPE=user
+
+# External Volume name (required only when TYPE=external)
+CLICKZETTA_VOLUME_NAME=
+
+# Table Volume table prefix (used only when TYPE=table)
+CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_
+
+# Dify file directory prefix (isolates from other apps, recommended to keep the default)
+CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km
+
+# S3 Configuration
+#
+S3_ENDPOINT=
+S3_REGION=us-east-1
+S3_BUCKET_NAME=difyai
+S3_ACCESS_KEY=
+S3_SECRET_KEY=
+# Whether to use AWS managed IAM roles for authenticating with the S3 service.
+# If set to false, the access key and secret key must be provided.
+S3_USE_AWS_MANAGED_IAM=false
+
+# Azure Blob Configuration
+#
+AZURE_BLOB_ACCOUNT_NAME=difyai
+AZURE_BLOB_ACCOUNT_KEY=difyai
+AZURE_BLOB_CONTAINER_NAME=difyai-container
+AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
+
+# Google Storage Configuration
+#
+GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
+GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=
+
+# The Alibaba Cloud OSS configurations.
+#
+ALIYUN_OSS_BUCKET_NAME=your-bucket-name
+ALIYUN_OSS_ACCESS_KEY=your-access-key
+ALIYUN_OSS_SECRET_KEY=your-secret-key
+ALIYUN_OSS_ENDPOINT=https://oss-ap-southeast-1-internal.aliyuncs.com
+ALIYUN_OSS_REGION=ap-southeast-1
+ALIYUN_OSS_AUTH_VERSION=v4
+# Don't start with '/'. OSS doesn't support a leading slash in object names.
+ALIYUN_OSS_PATH=your-path + +# Tencent COS Configuration +# +TENCENT_COS_BUCKET_NAME=your-bucket-name +TENCENT_COS_SECRET_KEY=your-secret-key +TENCENT_COS_SECRET_ID=your-secret-id +TENCENT_COS_REGION=your-region +TENCENT_COS_SCHEME=your-scheme + +# Oracle Storage Configuration +# +OCI_ENDPOINT=https://your-object-storage-namespace.compat.objectstorage.us-ashburn-1.oraclecloud.com +OCI_BUCKET_NAME=your-bucket-name +OCI_ACCESS_KEY=your-access-key +OCI_SECRET_KEY=your-secret-key +OCI_REGION=us-ashburn-1 + +# Huawei OBS Configuration +# +HUAWEI_OBS_BUCKET_NAME=your-bucket-name +HUAWEI_OBS_SECRET_KEY=your-secret-key +HUAWEI_OBS_ACCESS_KEY=your-access-key +HUAWEI_OBS_SERVER=your-server-url + +# Volcengine TOS Configuration +# +VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name +VOLCENGINE_TOS_SECRET_KEY=your-secret-key +VOLCENGINE_TOS_ACCESS_KEY=your-access-key +VOLCENGINE_TOS_ENDPOINT=your-server-url +VOLCENGINE_TOS_REGION=your-region + +# Baidu OBS Storage Configuration +# +BAIDU_OBS_BUCKET_NAME=your-bucket-name +BAIDU_OBS_SECRET_KEY=your-secret-key +BAIDU_OBS_ACCESS_KEY=your-access-key +BAIDU_OBS_ENDPOINT=your-server-url + +# Supabase Storage Configuration +# +SUPABASE_BUCKET_NAME=your-bucket-name +SUPABASE_API_KEY=your-access-key +SUPABASE_URL=your-server-url + +# ------------------------------ +# Vector Database Configuration +# ------------------------------ + +# The type of vector store to use. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. +VECTOR_STORE=weaviate + +# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. +WEAVIATE_ENDPOINT=http://weaviate:8080 +WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih + +# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`. +QDRANT_URL=http://qdrant:6333 +QDRANT_API_KEY=difyai123456 +QDRANT_CLIENT_TIMEOUT=20 +QDRANT_GRPC_ENABLED=false +QDRANT_GRPC_PORT=6334 +QDRANT_REPLICATION_FACTOR=1 + +# Milvus configuration. Only available when VECTOR_STORE is `milvus`. +# The milvus uri. 
+MILVUS_URI=http://host.docker.internal:19530 +MILVUS_DATABASE= +MILVUS_TOKEN= +MILVUS_USER= +MILVUS_PASSWORD= +MILVUS_ENABLE_HYBRID_SEARCH=False +MILVUS_ANALYZER_PARAMS= + +# MyScale configuration, only available when VECTOR_STORE is `myscale` +# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: +# https://myscale.com/docs/en/text-search/#understanding-fts-index-parameters +MYSCALE_HOST=myscale +MYSCALE_PORT=8123 +MYSCALE_USER=default +MYSCALE_PASSWORD= +MYSCALE_DATABASE=dify +MYSCALE_FTS_PARAMS= + +# Couchbase configurations, only available when VECTOR_STORE is `couchbase` +# The connection string must include hostname defined in the docker-compose file (couchbase-server in this case) +COUCHBASE_CONNECTION_STRING=couchbase://couchbase-server +COUCHBASE_USER=Administrator +COUCHBASE_PASSWORD=password +COUCHBASE_BUCKET_NAME=Embeddings +COUCHBASE_SCOPE_NAME=_default + +# pgvector configurations, only available when VECTOR_STORE is `pgvector` +PGVECTOR_HOST=pgvector +PGVECTOR_PORT=5432 +PGVECTOR_USER=postgres +PGVECTOR_PASSWORD=difyai123456 +PGVECTOR_DATABASE=dify +PGVECTOR_MIN_CONNECTION=1 +PGVECTOR_MAX_CONNECTION=5 +PGVECTOR_PG_BIGM=false +PGVECTOR_PG_BIGM_VERSION=1.2-20240606 + +# vastbase configurations, only available when VECTOR_STORE is `vastbase` +VASTBASE_HOST=vastbase +VASTBASE_PORT=5432 +VASTBASE_USER=dify +VASTBASE_PASSWORD=Difyai123456 +VASTBASE_DATABASE=dify +VASTBASE_MIN_CONNECTION=1 +VASTBASE_MAX_CONNECTION=5 + +# pgvecto-rs configurations, only available when VECTOR_STORE is `pgvecto-rs` +PGVECTO_RS_HOST=pgvecto-rs +PGVECTO_RS_PORT=5432 +PGVECTO_RS_USER=postgres +PGVECTO_RS_PASSWORD=difyai123456 +PGVECTO_RS_DATABASE=dify + +# analyticdb configurations, only available when VECTOR_STORE is `analyticdb` +ANALYTICDB_KEY_ID=your-ak +ANALYTICDB_KEY_SECRET=your-sk +ANALYTICDB_REGION_ID=cn-hangzhou +ANALYTICDB_INSTANCE_ID=gp-ab123456 +ANALYTICDB_ACCOUNT=testaccount +ANALYTICDB_PASSWORD=testpassword +ANALYTICDB_NAMESPACE=dify +ANALYTICDB_NAMESPACE_PASSWORD=difypassword +ANALYTICDB_HOST=gp-test.aliyuncs.com +ANALYTICDB_PORT=5432 +ANALYTICDB_MIN_CONNECTION=1 +ANALYTICDB_MAX_CONNECTION=5 + +# TiDB vector configurations, only available when VECTOR_STORE is `tidb_vector` +TIDB_VECTOR_HOST=tidb +TIDB_VECTOR_PORT=4000 +TIDB_VECTOR_USER= +TIDB_VECTOR_PASSWORD= +TIDB_VECTOR_DATABASE=dify + +# Matrixone vector configurations. 
+MATRIXONE_HOST=matrixone +MATRIXONE_PORT=6001 +MATRIXONE_USER=dump +MATRIXONE_PASSWORD=111 +MATRIXONE_DATABASE=dify + +# Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant` +TIDB_ON_QDRANT_URL=http://127.0.0.1 +TIDB_ON_QDRANT_API_KEY=dify +TIDB_ON_QDRANT_CLIENT_TIMEOUT=20 +TIDB_ON_QDRANT_GRPC_ENABLED=false +TIDB_ON_QDRANT_GRPC_PORT=6334 +TIDB_PUBLIC_KEY=dify +TIDB_PRIVATE_KEY=dify +TIDB_API_URL=http://127.0.0.1 +TIDB_IAM_API_URL=http://127.0.0.1 +TIDB_REGION=regions/aws-us-east-1 +TIDB_PROJECT_ID=dify +TIDB_SPEND_LIMIT=100 + +# Chroma configuration, only available when VECTOR_STORE is `chroma` +CHROMA_HOST=127.0.0.1 +CHROMA_PORT=8000 +CHROMA_TENANT=default_tenant +CHROMA_DATABASE=default_database +CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthClientProvider +CHROMA_AUTH_CREDENTIALS= + +# Oracle configuration, only available when VECTOR_STORE is `oracle` +ORACLE_USER=dify +ORACLE_PASSWORD=dify +ORACLE_DSN=oracle:1521/FREEPDB1 +ORACLE_CONFIG_DIR=/app/api/storage/wallet +ORACLE_WALLET_LOCATION=/app/api/storage/wallet +ORACLE_WALLET_PASSWORD=dify +ORACLE_IS_AUTONOMOUS=false + +# relyt configurations, only available when VECTOR_STORE is `relyt` +RELYT_HOST=db +RELYT_PORT=5432 +RELYT_USER=postgres +RELYT_PASSWORD=difyai123456 +RELYT_DATABASE=postgres + +# open search configuration, only available when VECTOR_STORE is `opensearch` +OPENSEARCH_HOST=opensearch +OPENSEARCH_PORT=9200 +OPENSEARCH_SECURE=true +OPENSEARCH_VERIFY_CERTS=true +OPENSEARCH_AUTH_METHOD=basic +OPENSEARCH_USER=admin +OPENSEARCH_PASSWORD=admin +# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless +OPENSEARCH_AWS_REGION=ap-southeast-1 +OPENSEARCH_AWS_SERVICE=aoss + +# tencent vector configurations, only available when VECTOR_STORE is `tencent` +TENCENT_VECTOR_DB_URL=http://127.0.0.1 +TENCENT_VECTOR_DB_API_KEY=dify +TENCENT_VECTOR_DB_TIMEOUT=30 +TENCENT_VECTOR_DB_USERNAME=dify +TENCENT_VECTOR_DB_DATABASE=dify +TENCENT_VECTOR_DB_SHARD=1 +TENCENT_VECTOR_DB_REPLICAS=2 +TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false + +# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` +ELASTICSEARCH_HOST=0.0.0.0 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic +KIBANA_PORT=5601 + +# baidu vector configurations, only available when VECTOR_STORE is `baidu` +BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 +BAIDU_VECTOR_DB_ACCOUNT=root +BAIDU_VECTOR_DB_API_KEY=dify +BAIDU_VECTOR_DB_DATABASE=dify +BAIDU_VECTOR_DB_SHARD=1 +BAIDU_VECTOR_DB_REPLICAS=3 + +# VikingDB configurations, only available when VECTOR_STORE is `vikingdb` +VIKINGDB_ACCESS_KEY=your-ak +VIKINGDB_SECRET_KEY=your-sk +VIKINGDB_REGION=cn-shanghai +VIKINGDB_HOST=api-vikingdb.xxx.volces.com +VIKINGDB_SCHEMA=http +VIKINGDB_CONNECTION_TIMEOUT=30 +VIKINGDB_SOCKET_TIMEOUT=30 + +# Lindorm configuration, only available when VECTOR_STORE is `lindorm` +LINDORM_URL=http://lindorm:30070 +LINDORM_USERNAME=lindorm +LINDORM_PASSWORD=lindorm +LINDORM_QUERY_TIMEOUT=1 + +# OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` +OCEANBASE_VECTOR_HOST=oceanbase +OCEANBASE_VECTOR_PORT=2881 +OCEANBASE_VECTOR_USER=root@test +OCEANBASE_VECTOR_PASSWORD=difyai123456 +OCEANBASE_VECTOR_DATABASE=test +OCEANBASE_CLUSTER_NAME=difyai +OCEANBASE_MEMORY_LIMIT=6G +OCEANBASE_ENABLE_HYBRID_SEARCH=false + +# opengauss configurations, only available when VECTOR_STORE is `opengauss` +OPENGAUSS_HOST=opengauss +OPENGAUSS_PORT=6600 
+OPENGAUSS_USER=postgres +OPENGAUSS_PASSWORD=Dify@123 +OPENGAUSS_DATABASE=dify +OPENGAUSS_MIN_CONNECTION=1 +OPENGAUSS_MAX_CONNECTION=5 +OPENGAUSS_ENABLE_PQ=false + +# huawei cloud search service vector configurations, only available when VECTOR_STORE is `huawei_cloud` +HUAWEI_CLOUD_HOSTS=https://127.0.0.1:9200 +HUAWEI_CLOUD_USER=admin +HUAWEI_CLOUD_PASSWORD=admin + +# Upstash Vector configuration, only available when VECTOR_STORE is `upstash` +UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io +UPSTASH_VECTOR_TOKEN=dify + +# TableStore Vector configuration +# (only used when VECTOR_STORE is tablestore) +TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com +TABLESTORE_INSTANCE_NAME=instance-name +TABLESTORE_ACCESS_KEY_ID=xxx +TABLESTORE_ACCESS_KEY_SECRET=xxx + +# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta` +CLICKZETTA_USERNAME= +CLICKZETTA_PASSWORD= +CLICKZETTA_INSTANCE= +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=quick_start +CLICKZETTA_VCLUSTER=default_ap +CLICKZETTA_SCHEMA=dify +CLICKZETTA_BATCH_SIZE=100 +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese +CLICKZETTA_ANALYZER_MODE=smart +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance + +# ------------------------------ +# Knowledge Configuration +# ------------------------------ + +# Upload file size limit, default 15M. +UPLOAD_FILE_SIZE_LIMIT=15 + +# The maximum number of files that can be uploaded at a time, default 5. +UPLOAD_FILE_BATCH_LIMIT=5 + +# ETL type, support: `dify`, `Unstructured` +# `dify` Dify's proprietary file extraction scheme +# `Unstructured` Unstructured.io file extraction scheme +ETL_TYPE=dify + +# Unstructured API path and API key, needs to be configured when ETL_TYPE is Unstructured +# Or using Unstructured for document extractor node for pptx. +# For example: http://unstructured:8000/general/v0/general +UNSTRUCTURED_API_URL= +UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=true + +# ------------------------------ +# Model Configuration +# ------------------------------ + +# The maximum number of tokens allowed for prompt generation. +# This setting controls the upper limit of tokens that can be used by the LLM +# when generating a prompt in the prompt generation tool. +# Default: 512 tokens. +PROMPT_GENERATION_MAX_TOKENS=512 + +# The maximum number of tokens allowed for code generation. +# This setting controls the upper limit of tokens that can be used by the LLM +# when generating code in the code generation tool. +# Default: 1024 tokens. +CODE_GENERATION_MAX_TOKENS=1024 + +# Enable or disable plugin based token counting. If disabled, token counting will return 0. +# This can improve performance by skipping token counting operations. +# Default: false (disabled). +PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false + +# ------------------------------ +# Multi-modal Configuration +# ------------------------------ + +# The format of the image/video/audio/document sent when the multi-modal model is input, +# the default is base64, optional url. +# The delay of the call in url mode will be lower than that in base64 mode. +# It is generally recommended to use the more compatible base64 mode. +# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document. +MULTIMODAL_SEND_FORMAT=base64 +# Upload image file size limit, default 10M. +UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +# Upload video file size limit, default 100M. 
+UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +# Upload audio file size limit, default 50M. +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 + +# ------------------------------ +# Sentry Configuration +# Used for application monitoring and error log tracking. +# ------------------------------ +SENTRY_DSN= + +# API Service Sentry DSN address, default is empty, when empty, +# all monitoring information is not reported to Sentry. +# If not set, Sentry error reporting will be disabled. +API_SENTRY_DSN= +# API Service The reporting ratio of Sentry events, if it is 0.01, it is 1%. +API_SENTRY_TRACES_SAMPLE_RATE=1.0 +# API Service The reporting ratio of Sentry profiles, if it is 0.01, it is 1%. +API_SENTRY_PROFILES_SAMPLE_RATE=1.0 + +# Web Service Sentry DSN address, default is empty, when empty, +# all monitoring information is not reported to Sentry. +# If not set, Sentry error reporting will be disabled. +WEB_SENTRY_DSN= + +# ------------------------------ +# Notion Integration Configuration +# Variables can be obtained by applying for Notion integration: https://www.notion.so/my-integrations +# ------------------------------ + +# Configure as "public" or "internal". +# Since Notion's OAuth redirect URL only supports HTTPS, +# if deploying locally, please use Notion's internal integration. +NOTION_INTEGRATION_TYPE=public +# Notion OAuth client secret (used for public integration type) +NOTION_CLIENT_SECRET= +# Notion OAuth client id (used for public integration type) +NOTION_CLIENT_ID= +# Notion internal integration secret. +# If the value of NOTION_INTEGRATION_TYPE is "internal", +# you need to configure this variable. +NOTION_INTERNAL_SECRET= + +# ------------------------------ +# Mail related configuration +# ------------------------------ + +# Mail type, support: resend, smtp, sendgrid +MAIL_TYPE=resend + +# Default send from email address, if not specified +# If using SendGrid, use the 'from' field for authentication if necessary. +MAIL_DEFAULT_SEND_FROM= + +# API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. +RESEND_API_URL=https://api.resend.com +RESEND_API_KEY=your-resend-api-key + + +# SMTP server configuration, used when MAIL_TYPE is `smtp` +SMTP_SERVER= +SMTP_PORT=465 +SMTP_USERNAME= +SMTP_PASSWORD= +SMTP_USE_TLS=true +SMTP_OPPORTUNISTIC_TLS=false + +# Sendgid configuration +SENDGRID_API_KEY= + +# ------------------------------ +# Others Configuration +# ------------------------------ + +# Maximum length of segmentation tokens for indexing +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 + +# Member invitation link valid time (hours), +# Default: 72. +INVITE_EXPIRY_HOURS=72 + +# Reset password token valid time (minutes), +RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 + +# The sandbox service endpoint. 
+CODE_EXECUTION_ENDPOINT=http://sandbox:8194 +CODE_EXECUTION_API_KEY=dify-sandbox +CODE_MAX_NUMBER=9223372036854775807 +CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_DEPTH=5 +CODE_MAX_PRECISION=20 +CODE_MAX_STRING_LENGTH=80000 +CODE_MAX_STRING_ARRAY_LENGTH=30 +CODE_MAX_OBJECT_ARRAY_LENGTH=30 +CODE_MAX_NUMBER_ARRAY_LENGTH=1000 +CODE_EXECUTION_CONNECT_TIMEOUT=10 +CODE_EXECUTION_READ_TIMEOUT=60 +CODE_EXECUTION_WRITE_TIMEOUT=10 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 + +# Workflow runtime configuration +WORKFLOW_MAX_EXECUTION_STEPS=500 +WORKFLOW_MAX_EXECUTION_TIME=1200 +WORKFLOW_CALL_MAX_DEPTH=5 +MAX_VARIABLE_SIZE=204800 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 +WORKFLOW_FILE_UPLOAD_LIMIT=10 + +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + +# HTTP request node in workflow configuration +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 +HTTP_REQUEST_NODE_SSL_VERIFY=True + +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + +# SSRF Proxy server HTTP URL +SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 +# SSRF Proxy server HTTPS URL +SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 + +# Maximum loop count in the workflow +LOOP_NODE_MAX_COUNT=100 + +# The maximum number of tools that can be used in the agent. +MAX_TOOLS_NUM=10 + +# Maximum number of Parallelism branches in the workflow +MAX_PARALLEL_LIMIT=10 + +# The maximum number of iterations for agent setting +MAX_ITERATIONS_NUM=99 + +# ------------------------------ +# Environment Variables for web Service +# ------------------------------ + +# The timeout for the text generation in millisecond +TEXT_GENERATION_TIMEOUT_MS=60000 + +# Allow rendering unsafe URLs which have "data:" scheme. +ALLOW_UNSAFE_DATA_SCHEME=false + +# ------------------------------ +# Environment Variables for db Service +# ------------------------------ + +# The name of the default postgres user. +POSTGRES_USER=${DB_USERNAME} +# The password for the default postgres user. +POSTGRES_PASSWORD=${DB_PASSWORD} +# The name of the default postgres database. 
+POSTGRES_DB=${DB_DATABASE} +# postgres data directory +PGDATA=/var/lib/postgresql/data/pgdata + +# ------------------------------ +# Environment Variables for sandbox Service +# ------------------------------ + +# The API key for the sandbox service +SANDBOX_API_KEY=dify-sandbox +# The mode in which the Gin framework runs +SANDBOX_GIN_MODE=release +# The timeout for the worker in seconds +SANDBOX_WORKER_TIMEOUT=15 +# Enable network for the sandbox service +SANDBOX_ENABLE_NETWORK=true +# HTTP proxy URL for SSRF protection +SANDBOX_HTTP_PROXY=http://ssrf_proxy:3128 +# HTTPS proxy URL for SSRF protection +SANDBOX_HTTPS_PROXY=http://ssrf_proxy:3128 +# The port on which the sandbox service runs +SANDBOX_PORT=8194 + +# ------------------------------ +# Environment Variables for weaviate Service +# (only used when VECTOR_STORE is weaviate) +# ------------------------------ +WEAVIATE_PERSISTENCE_DATA_PATH=/var/lib/weaviate +WEAVIATE_QUERY_DEFAULTS_LIMIT=25 +WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true +WEAVIATE_DEFAULT_VECTORIZER_MODULE=none +WEAVIATE_CLUSTER_HOSTNAME=node1 +WEAVIATE_AUTHENTICATION_APIKEY_ENABLED=true +WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih +WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai +WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true +WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai + +# ------------------------------ +# Environment Variables for Chroma +# (only used when VECTOR_STORE is chroma) +# ------------------------------ + +# Authentication credentials for Chroma server +CHROMA_SERVER_AUTHN_CREDENTIALS=difyai123456 +# Authentication provider for Chroma server +CHROMA_SERVER_AUTHN_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider +# Persistence setting for Chroma server +CHROMA_IS_PERSISTENT=TRUE + +# ------------------------------ +# Environment Variables for Oracle Service +# (only used when VECTOR_STORE is oracle) +# ------------------------------ +ORACLE_PWD=Dify123456 +ORACLE_CHARACTERSET=AL32UTF8 + +# ------------------------------ +# Environment Variables for milvus Service +# (only used when VECTOR_STORE is milvus) +# ------------------------------ +# ETCD configuration for auto compaction mode +ETCD_AUTO_COMPACTION_MODE=revision +# ETCD configuration for auto compaction retention in terms of number of revisions +ETCD_AUTO_COMPACTION_RETENTION=1000 +# ETCD configuration for backend quota in bytes +ETCD_QUOTA_BACKEND_BYTES=4294967296 +# ETCD configuration for the number of changes before triggering a snapshot +ETCD_SNAPSHOT_COUNT=50000 +# MinIO access key for authentication +MINIO_ACCESS_KEY=minioadmin +# MinIO secret key for authentication +MINIO_SECRET_KEY=minioadmin +# ETCD service endpoints +ETCD_ENDPOINTS=etcd:2379 +# MinIO service address +MINIO_ADDRESS=minio:9000 +# Enable or disable security authorization +MILVUS_AUTHORIZATION_ENABLED=true + +# ------------------------------ +# Environment Variables for pgvector / pgvector-rs Service +# (only used when VECTOR_STORE is pgvector / pgvector-rs) +# ------------------------------ +PGVECTOR_PGUSER=postgres +# The password for the default postgres user. +PGVECTOR_POSTGRES_PASSWORD=difyai123456 +# The name of the default postgres database. 
+PGVECTOR_POSTGRES_DB=dify +# postgres data directory +PGVECTOR_PGDATA=/var/lib/postgresql/data/pgdata + +# ------------------------------ +# Environment Variables for opensearch +# (only used when VECTOR_STORE is opensearch) +# ------------------------------ +OPENSEARCH_DISCOVERY_TYPE=single-node +OPENSEARCH_BOOTSTRAP_MEMORY_LOCK=true +OPENSEARCH_JAVA_OPTS_MIN=512m +OPENSEARCH_JAVA_OPTS_MAX=1024m +OPENSEARCH_INITIAL_ADMIN_PASSWORD=Qazwsxedc!@#123 +OPENSEARCH_MEMLOCK_SOFT=-1 +OPENSEARCH_MEMLOCK_HARD=-1 +OPENSEARCH_NOFILE_SOFT=65536 +OPENSEARCH_NOFILE_HARD=65536 + +# ------------------------------ +# Environment Variables for Nginx reverse proxy +# ------------------------------ +NGINX_SERVER_NAME=_ +NGINX_HTTPS_ENABLED=false +# HTTP port +NGINX_PORT=80 +# SSL settings are only applied when HTTPS_ENABLED is true +NGINX_SSL_PORT=443 +# if HTTPS_ENABLED is true, you're required to add your own SSL certificates/keys to the `./nginx/ssl` directory +# and modify the env vars below accordingly. +NGINX_SSL_CERT_FILENAME=dify.crt +NGINX_SSL_CERT_KEY_FILENAME=dify.key +NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3 + +# Nginx performance tuning +NGINX_WORKER_PROCESSES=auto +NGINX_CLIENT_MAX_BODY_SIZE=100M +NGINX_KEEPALIVE_TIMEOUT=65 + +# Proxy settings +NGINX_PROXY_READ_TIMEOUT=3600s +NGINX_PROXY_SEND_TIMEOUT=3600s + +# Set true to accept requests for /.well-known/acme-challenge/ +NGINX_ENABLE_CERTBOT_CHALLENGE=false + +# ------------------------------ +# Certbot Configuration +# ------------------------------ + +# Email address (required to get certificates from Let's Encrypt) +CERTBOT_EMAIL=your_email@example.com + +# Domain name +CERTBOT_DOMAIN=your_domain.com + +# certbot command options +# i.e: --force-renewal --dry-run --test-cert --debug +CERTBOT_OPTIONS= + +# ------------------------------ +# Environment Variables for SSRF Proxy +# ------------------------------ +SSRF_HTTP_PORT=3128 +SSRF_COREDUMP_DIR=/var/spool/squid +SSRF_REVERSE_PROXY_PORT=8194 +SSRF_SANDBOX_HOST=sandbox +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 + +# ------------------------------ +# docker env var for specifying vector db type at startup +# (based on the vector db type, the corresponding docker +# compose profile will be used) +# if you want to use unstructured, add ',unstructured' to the end +# ------------------------------ +COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} + +# ------------------------------ +# Docker Compose Service Expose Host Port Configurations +# ------------------------------ +EXPOSE_NGINX_PORT=80 +EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. 
+# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= + +# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP +CSP_WHITELIST= + +# Enable or disable create tidb service job +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 + +# The maximum number of top-k value for RAG. +TOP_K_MAX_VALUE=10 + +# ------------------------------ +# Plugin Daemon Configuration +# ------------------------------ + +DB_PLUGIN_DATABASE=dify_plugin +EXPOSE_PLUGIN_DAEMON_PORT=5002 +PLUGIN_DAEMON_PORT=5002 +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DAEMON_URL=http://plugin_daemon:5002 +PLUGIN_MAX_PACKAGE_SIZE=52428800 +PLUGIN_PPROF_ENABLED=false + +PLUGIN_DEBUGGING_HOST=0.0.0.0 +PLUGIN_DEBUGGING_PORT=5003 +EXPOSE_PLUGIN_DEBUGGING_HOST=localhost +EXPOSE_PLUGIN_DEBUGGING_PORT=5003 + +# If this key is changed, DIFY_INNER_API_KEY in plugin_daemon service must also be updated or agent node will fail. +PLUGIN_DIFY_INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 +PLUGIN_DIFY_INNER_API_URL=http://api:5001 + +ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} + +MARKETPLACE_ENABLED=true +MARKETPLACE_API_URL=https://marketplace.dify.ai + +FORCE_VERIFYING_SIGNATURE=true + +PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 +PLUGIN_MAX_EXECUTION_TIMEOUT=600 +# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple +PIP_MIRROR_URL= + +# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example +# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos +PLUGIN_STORAGE_TYPE=local +PLUGIN_STORAGE_LOCAL_ROOT=/app/storage +PLUGIN_WORKING_PATH=/app/storage/cwd +PLUGIN_INSTALLED_PATH=plugin +PLUGIN_PACKAGE_CACHE_PATH=plugin_packages +PLUGIN_MEDIA_CACHE_PATH=assets +# Plugin oss bucket +PLUGIN_STORAGE_OSS_BUCKET= +# Plugin oss s3 credentials +PLUGIN_S3_USE_AWS=false +PLUGIN_S3_USE_AWS_MANAGED_IAM=false +PLUGIN_S3_ENDPOINT= +PLUGIN_S3_USE_PATH_STYLE=false +PLUGIN_AWS_ACCESS_KEY= +PLUGIN_AWS_SECRET_KEY= +PLUGIN_AWS_REGION= +# Plugin oss azure blob +PLUGIN_AZURE_BLOB_STORAGE_CONTAINER_NAME= +PLUGIN_AZURE_BLOB_STORAGE_CONNECTION_STRING= +# Plugin oss tencent cos +PLUGIN_TENCENT_COS_SECRET_KEY= +PLUGIN_TENCENT_COS_SECRET_ID= +PLUGIN_TENCENT_COS_REGION= +# Plugin oss aliyun oss +PLUGIN_ALIYUN_OSS_REGION= +PLUGIN_ALIYUN_OSS_ENDPOINT= +PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID= +PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= +PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 +PLUGIN_ALIYUN_OSS_PATH= +# Plugin oss volcengine tos +PLUGIN_VOLCENGINE_TOS_ENDPOINT= +PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= +PLUGIN_VOLCENGINE_TOS_SECRET_KEY= +PLUGIN_VOLCENGINE_TOS_REGION= + +# ------------------------------ +# OTLP Collector Configuration +# ------------------------------ +ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= +OTLP_BASE_ENDPOINT=http://localhost:4318 +OTLP_API_KEY= +OTEL_EXPORTER_OTLP_PROTOCOL= +OTEL_EXPORTER_TYPE=otlp +OTEL_SAMPLING_RATE=0.1 +OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000 +OTEL_MAX_QUEUE_SIZE=2048 +OTEL_MAX_EXPORT_BATCH_SIZE=512 +OTEL_METRIC_EXPORT_INTERVAL=60000 +OTEL_BATCH_EXPORT_TIMEOUT=10000 +OTEL_METRIC_EXPORT_TIMEOUT=30000 + +# Prevent Clickjacking +ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. 
eg: test1@dify.ai,test2@dify.ai +QUEUE_MONITOR_ALERT_EMAILS= +# Monitor interval in minutes, default is 30 minutes +QUEUE_MONITOR_INTERVAL=30 diff --git a/.github/ISSUE_TEMPLATE/chore.yaml b/.github/ISSUE_TEMPLATE/chore.yaml index 43449ef942..cf74dcc546 100644 --- a/.github/ISSUE_TEMPLATE/chore.yaml +++ b/.github/ISSUE_TEMPLATE/chore.yaml @@ -4,6 +4,23 @@ title: "[Chore/Refactor] " labels: - refactor body: + - type: checkboxes + attributes: + label: Self Checks + description: "To make sure we get to you in time, please check the following :)" + options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true + - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). + required: true + - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. + required: true + - label: I confirm that I am using English to submit this report, otherwise it will be closed. + required: true + - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :) + required: true + - label: "Please do not modify this template :) and fill in all the required fields." + required: true - type: textarea id: description attributes: diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index c79d58563f..4b06174ee1 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -1,13 +1,18 @@ name: Check i18n Files and Create PR on: - pull_request: - types: [closed] + push: branches: [main] + paths: + - 'web/i18n/en-US/*.ts' + +permissions: + contents: write + pull-requests: write jobs: check-and-update: - if: github.event.pull_request.merged == true + if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest defaults: run: @@ -15,8 +20,8 @@ jobs: steps: - uses: actions/checkout@v4 with: - fetch-depth: 2 # last 2 commits - persist-credentials: false + fetch-depth: 2 + token: ${{ secrets.GITHUB_TOKEN }} - name: Check for file changes in i18n/en-US id: check_files @@ -27,6 +32,13 @@ jobs: echo "Changed files: $changed_files" if [ -n "$changed_files" ]; then echo "FILES_CHANGED=true" >> $GITHUB_ENV + file_args="" + for file in $changed_files; do + filename=$(basename "$file" .ts) + file_args="$file_args --file=$filename" + done + echo "FILE_ARGS=$file_args" >> $GITHUB_ENV + echo "File arguments: $file_args" else echo "FILES_CHANGED=false" >> $GITHUB_ENV fi @@ -49,14 +61,15 @@ jobs: if: env.FILES_CHANGED == 'true' run: pnpm install --frozen-lockfile - - name: Run npm script + - name: Generate i18n translations if: env.FILES_CHANGED == 'true' - run: pnpm run auto-gen-i18n + run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} - name: Create Pull Request if: env.FILES_CHANGED == 'true' uses: peter-evans/create-pull-request@v6 with: + token: ${{ secrets.GITHUB_TOKEN }} commit-message: Update i18n files based on en-US changes title: 'chore: translate i18n files' body: This PR was automatically created to update i18n files based on changes in en-US locale. 
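The variables in the .env.example above are read by the API service through pydantic-settings classes such as the ClickZetta configuration classes added further down in this diff. The sketch below is only a minimal illustration of that pattern, not Dify's actual code: the class name ExampleVolumeConfig is hypothetical, while the two field names are taken from the CLICKZETTA_VOLUME_* entries in the template.

```python
# Minimal sketch of the pydantic-settings pattern used by the configuration
# classes in this diff. ExampleVolumeConfig is illustrative, not Dify's code.
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class ExampleVolumeConfig(BaseSettings):
    """Reads CLICKZETTA_VOLUME_* values from the process environment."""

    CLICKZETTA_VOLUME_TYPE: str = Field(
        description="Volume type: user, table, or external",
        default="user",
    )
    CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
        description="Directory prefix that isolates Dify files",
        default="dify_km",
    )


if __name__ == "__main__":
    # Normally supplied via the .env file / docker-compose; set here only for the demo.
    os.environ["CLICKZETTA_VOLUME_TYPE"] = "table"
    config = ExampleVolumeConfig()
    print(config.CLICKZETTA_VOLUME_TYPE, config.CLICKZETTA_VOLUME_DIFY_PREFIX)  # table dify_km
```

In the compose setup, these values come from the .env file rather than being set in code, and any value present in the environment overrides the declared default.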
diff --git a/.gitignore b/.gitignore index dd4673a3d2..5c68d89a4d 100644 --- a/.gitignore +++ b/.gitignore @@ -215,3 +215,4 @@ mise.toml # AI Assistant .roo/ api/.env.backup +/clickzetta diff --git a/api/Dockerfile b/api/Dockerfile index e097b5811e..d69291f7ea 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -19,7 +19,7 @@ RUN apt-get update \ # Install Python dependencies COPY pyproject.toml uv.lock ./ -RUN uv sync --locked +RUN uv sync --locked --no-dev # production stage FROM base AS production diff --git a/api/commands.py b/api/commands.py index 8177f1a48c..8ee52ba716 100644 --- a/api/commands.py +++ b/api/commands.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from flask import current_app from pydantic import TypeAdapter from sqlalchemy import select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError from configs import dify_config from constants.languages import languages @@ -181,8 +181,8 @@ def migrate_annotation_vector_database(): ) if not apps: break - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for app in apps: @@ -308,8 +308,8 @@ def migrate_knowledge_vector_database(): ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for dataset in datasets: @@ -561,8 +561,8 @@ def old_metadata_migration(): .order_by(DatasetDocument.created_at.desc()) ) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise if not documents: break for document in documents: diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 9f1646ea7d..4dbc8207f1 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -330,17 +330,17 @@ class HttpConfig(BaseSettings): def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ - PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests") - ] = 10 + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field( + ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10 + ) - HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ - PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests") - ] = 60 + HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( + ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 + ) - HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ - PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests") - ] = 20 + HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( + ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 + ) HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( description="Maximum allowed size in bytes for binary data in HTTP requests", diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index ff290ff99d..4e228ab932 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from .storage.amazon_s3_storage_config import S3StorageConfig from .storage.azure_blob_storage_config import AzureBlobStorageConfig from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig +from .storage.clickzetta_volume_storage_config import 
ClickZettaVolumeStorageConfig from .storage.google_cloud_storage_config import GoogleCloudStorageConfig from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig from .storage.oci_storage_config import OCIStorageConfig @@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from .vdb.analyticdb_config import AnalyticdbConfig from .vdb.baidu_vector_config import BaiduVectorDBConfig from .vdb.chroma_config import ChromaConfig +from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig @@ -52,6 +54,7 @@ class StorageConfig(BaseSettings): "aliyun-oss", "azure-blob", "baidu-obs", + "clickzetta-volume", "google-storage", "huawei-obs", "oci-storage", @@ -61,8 +64,9 @@ class StorageConfig(BaseSettings): "local", ] = Field( description="Type of storage to use." - " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', " - "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.", + " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', " + "'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', " + "'volcengine-tos', 'supabase'. Default is 'opendal'.", default="opendal", ) @@ -303,6 +307,7 @@ class MiddlewareConfig( AliyunOSSStorageConfig, AzureBlobStorageConfig, BaiduOBSStorageConfig, + ClickZettaVolumeStorageConfig, GoogleCloudStorageConfig, HuaweiCloudOBSStorageConfig, OCIStorageConfig, @@ -315,6 +320,7 @@ class MiddlewareConfig( VectorStoreConfig, AnalyticdbConfig, ChromaConfig, + ClickzettaConfig, HuaweiCloudConfig, MilvusConfig, MyScaleConfig, diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py new file mode 100644 index 0000000000..56e1b6a957 --- /dev/null +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -0,0 +1,65 @@ +"""ClickZetta Volume Storage Configuration""" + +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class ClickZettaVolumeStorageConfig(BaseSettings): + """Configuration for ClickZetta Volume storage.""" + + CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( + description="Username for ClickZetta Volume authentication", + default=None, + ) + + CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( + description="Password for ClickZetta Volume authentication", + default=None, + ) + + CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( + description="ClickZetta instance identifier", + default=None, + ) + + CLICKZETTA_VOLUME_SERVICE: str = Field( + description="ClickZetta service endpoint", + default="api.clickzetta.com", + ) + + CLICKZETTA_VOLUME_WORKSPACE: str = Field( + description="ClickZetta workspace name", + default="quick_start", + ) + + CLICKZETTA_VOLUME_VCLUSTER: str = Field( + description="ClickZetta virtual cluster name", + default="default_ap", + ) + + CLICKZETTA_VOLUME_SCHEMA: str = Field( + description="ClickZetta schema name", + default="dify", + ) + + CLICKZETTA_VOLUME_TYPE: str = Field( + description="ClickZetta volume type (table|user|external)", + default="user", + ) + + CLICKZETTA_VOLUME_NAME: Optional[str] = Field( + description="ClickZetta volume name for external volumes", + default=None, + ) + + 
CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field( + description="Prefix for ClickZetta volume table names", + default="dataset_", + ) + + CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field( + description="Directory prefix for User Volume to organize Dify files", + default="dify_km", + ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py new file mode 100644 index 0000000000..04f81e25fc --- /dev/null +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -0,0 +1,69 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class ClickzettaConfig(BaseModel): + """ + Clickzetta Lakehouse vector database configuration + """ + + CLICKZETTA_USERNAME: Optional[str] = Field( + description="Username for authenticating with Clickzetta Lakehouse", + default=None, + ) + + CLICKZETTA_PASSWORD: Optional[str] = Field( + description="Password for authenticating with Clickzetta Lakehouse", + default=None, + ) + + CLICKZETTA_INSTANCE: Optional[str] = Field( + description="Clickzetta Lakehouse instance ID", + default=None, + ) + + CLICKZETTA_SERVICE: Optional[str] = Field( + description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')", + default="api.clickzetta.com", + ) + + CLICKZETTA_WORKSPACE: Optional[str] = Field( + description="Clickzetta workspace name", + default="default", + ) + + CLICKZETTA_VCLUSTER: Optional[str] = Field( + description="Clickzetta virtual cluster name", + default="default_ap", + ) + + CLICKZETTA_SCHEMA: Optional[str] = Field( + description="Database schema name in Clickzetta", + default="public", + ) + + CLICKZETTA_BATCH_SIZE: Optional[int] = Field( + description="Batch size for bulk insert operations", + default=100, + ) + + CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field( + description="Enable inverted index for full-text search capabilities", + default=True, + ) + + CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field( + description="Analyzer type for full-text search: keyword, english, chinese, unicode", + default="chinese", + ) + + CLICKZETTA_ANALYZER_MODE: Optional[str] = Field( + description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)", + default="smart", + ) + + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( + description="Distance function for vector similarity: l2_distance or cosine_distance", + default="cosine_distance", + ) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 007b1f6d3d..ee6011cd65 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -225,14 +225,15 @@ class AnnotationBatchImportApi(Resource): raise Forbidden() app_id = str(app_id) - # get file from request - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + + # get file from request + file = request.files["file"] # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. 
Only CSV files are allowed") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9fe32dde6d..1cc13d669c 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -28,6 +28,12 @@ from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +def _validate_description_length(description): + if description and len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + class AppListApi(Resource): @setup_required @login_required @@ -94,7 +100,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") @@ -146,7 +152,7 @@ class AppApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -189,7 +195,7 @@ class AppCopyApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f551bc2432..2befd2a651 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -41,7 +41,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -113,7 +113,7 @@ class DatasetListApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", @@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.HUAWEI_CLOUD | VectorType.TENCENT | VectorType.MATRIXONE + | VectorType.CLICKZETTA ): return { "retrieval_method": [ @@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.TENCENT | VectorType.HUAWEI_CLOUD | VectorType.MATRIXONE + | VectorType.CLICKZETTA ): return { "retrieval_method": [ diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 66b6214f82..256ff24b3b 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -49,7 +49,6 @@ class FileApi(Resource): @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") def post(self): - file = request.files["file"] 
source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None @@ -58,6 +57,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 19999e7361..6012c9ecc8 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): - # get file from request - file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() @@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + # get file from request + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d964e27819..b26f29d98d 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) from . import index -from .app import annotation, app, audio, completion, conversation, file, message, site, workflow +from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ca91da80c1..ba705f71e2 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException): error_code = "unsupported_file_type" description = "File type not allowed." code = 415 + + +class FileNotFoundError(BaseHTTPException): + error_code = "file_not_found" + description = "The requested file was not found." + code = 404 + + +class FileAccessDeniedError(BaseHTTPException): + error_code = "file_access_denied" + description = "Access to the requested file is denied." 
+ code = 403 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef..f09d07bcb6 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -20,18 +20,17 @@ class FileApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() - if not file.mimetype: - raise UnsupportedFileTypeError() - if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] + if not file.mimetype: + raise UnsupportedFileTypeError() + if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py new file mode 100644 index 0000000000..57141033d1 --- /dev/null +++ b/api/controllers/service_api/app/file_preview.py @@ -0,0 +1,186 @@ +import logging +from urllib.parse import quote + +from flask import Response +from flask_restful import Resource, reqparse + +from controllers.service_api import api +from controllers.service_api.app.error import ( + FileAccessDeniedError, + FileNotFoundError, +) +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import App, EndUser, Message, MessageFile, UploadFile + +logger = logging.getLogger(__name__) + + +class FilePreviewApi(Resource): + """ + Service API File Preview endpoint + + Provides secure file preview/download functionality for external API users. + Files can only be accessed if they belong to messages within the requesting app's context. + """ + + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + def get(self, app_model: App, end_user: EndUser, file_id: str): + """ + Preview/Download a file that was uploaded via Service API + + Args: + app_model: The authenticated app model + end_user: The authenticated end user (optional) + file_id: UUID of the file to preview + + Query Parameters: + user: Optional user identifier + as_attachment: Boolean, whether to download as attachment (default: false) + + Returns: + Stream response with file content + + Raises: + FileNotFoundError: File does not exist + FileAccessDeniedError: File access denied (not owned by app) + """ + file_id = str(file_id) + + # Parse query parameters + parser = reqparse.RequestParser() + parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") + args = parser.parse_args() + + # Validate file ownership and get file objects + message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) + + # Get file content generator + try: + generator = storage.load(upload_file.key, stream=True) + except Exception as e: + raise FileNotFoundError(f"Failed to load file content: {str(e)}") + + # Build response with appropriate headers + response = self._build_file_response(generator, upload_file, args["as_attachment"]) + + return response + + def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]: + """ + Validate that the file belongs to a message within the requesting app's context + + Security validations performed: + 1. File exists in MessageFile table (was used in a conversation) + 2. 
Message belongs to the requesting app + 3. UploadFile record exists and is accessible + 4. File tenant matches app tenant (additional security layer) + + Args: + file_id: UUID of the file to validate + app_id: UUID of the requesting app + + Returns: + Tuple of (MessageFile, UploadFile) if validation passes + + Raises: + FileNotFoundError: File or related records not found + FileAccessDeniedError: File does not belong to the app's context + """ + try: + # Input validation + if not file_id or not app_id: + raise FileAccessDeniedError("Invalid file or app identifier") + + # First, find the MessageFile that references this upload file + message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + + if not message_file: + raise FileNotFoundError("File not found in message context") + + # Get the message and verify it belongs to the requesting app + message = ( + db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + ) + + if not message: + raise FileAccessDeniedError("File access denied: not owned by requesting app") + + # Get the actual upload file record + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + raise FileNotFoundError("Upload file record not found") + + # Additional security: verify tenant isolation + app = db.session.query(App).where(App.id == app_id).first() + if app and upload_file.tenant_id != app.tenant_id: + raise FileAccessDeniedError("File access denied: tenant mismatch") + + return message_file, upload_file + + except (FileNotFoundError, FileAccessDeniedError): + # Re-raise our custom exceptions + raise + except Exception as e: + # Log unexpected errors for debugging + logger.exception( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": str(e)}, + ) + raise FileAccessDeniedError("File access validation failed") + + def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response: + """ + Build Flask Response object with appropriate headers for file streaming + + Args: + generator: File content generator from storage + upload_file: UploadFile database record + as_attachment: Whether to set Content-Disposition as attachment + + Returns: + Flask Response object with streaming file content + """ + response = Response( + generator, + mimetype=upload_file.mime_type, + direct_passthrough=True, + headers={}, + ) + + # Add Content-Length if known + if upload_file.size and upload_file.size > 0: + response.headers["Content-Length"] = str(upload_file.size) + + # Add Accept-Ranges header for audio/video files to support seeking + if upload_file.mime_type in [ + "audio/mpeg", + "audio/wav", + "audio/mp4", + "audio/ogg", + "audio/flac", + "audio/aac", + "video/mp4", + "video/webm", + "video/quicktime", + "audio/x-m4a", + ]: + response.headers["Accept-Ranges"] = "bytes" + + # Set Content-Disposition for downloads + if as_attachment and upload_file.name: + encoded_filename = quote(upload_file.name) + response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" + # Override content-type for downloads to force download + response.headers["Content-Type"] = "application/octet-stream" + + # Add caching headers for performance + response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour + + return response + + +# Register the API endpoint +api.add_resource(FilePreviewApi, "/files//preview") diff 
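For external callers, the new `FilePreviewApi` above can be exercised with any HTTP client. The sketch below is illustrative only: it assumes the Service API is served under the `/v1` prefix declared in `api/controllers/service_api/__init__.py`, that the route's path parameter is the upload file id, and that `API_BASE`, `API_KEY`, and `FILE_ID` are placeholders supplied by the caller.

```python
import requests

API_BASE = "https://api.example.com/v1"  # placeholder: your Service API base URL
API_KEY = "app-xxxxxxxx"                 # placeholder: the app's Service API key
FILE_ID = "00000000-0000-0000-0000-000000000000"  # placeholder: id of a file used in a message

# Stream the file; set as_attachment=true to force a download
# (application/octet-stream plus a Content-Disposition header).
resp = requests.get(
    f"{API_BASE}/files/{FILE_ID}/preview",
    headers={"Authorization": f"Bearer {API_KEY}"},
    params={"as_attachment": "false"},
    stream=True,
    timeout=30,
)
resp.raise_for_status()

with open("preview.bin", "wb") as out:
    for chunk in resp.iter_content(chunk_size=8192):
        out.write(chunk)
```

A 404 (`file_not_found`) or 403 (`file_access_denied`) response corresponds to the `FileNotFoundError` and `FileAccessDeniedError` classes added to `error.py` above.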
--git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index a499719fc3..29eef41253 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -29,7 +29,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 77600aa18c..2955d5d20d 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource): args["retrieval_model"].get("reranking_model").get("reranking_model_name"), ) - # save file info - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() @@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() + # save file info + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 94a525a75d..197859e8f3 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,5 +1,6 @@ from flask import request from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import Unauthorized from controllers.common import fields from controllers.web import api @@ -75,14 +76,14 @@ class AppWebAuthPermission(Resource): try: auth_header = request.headers.get("Authorization") if auth_header is None: - raise + raise Unauthorized("Authorization header is missing.") if " " not in auth_header: - raise + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": - raise + raise Unauthorized("Authorization scheme must be 'Bearer'") decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index df06a73a85..8e9317606e 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -12,18 +12,17 @@ from services.file_service import FileService class FileApi(WebApiResource): @marshal_with(file_fields) def post(self, app_model, end_user): - file = request.files["file"] - source = request.form.get("source") - if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] if not file.filename: raise FilenameNotExistsError + source = request.form.get("source") if source not in ("datasets", None): source = None diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a75e17af64..3de2f5ca9e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ): return - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, - ConversationVariable.conversation_id == self.conversation.id, - ) - with Session(db.engine) as session: - db_conversation_variables = session.scalars(stmt).all() - if not db_conversation_variables: - # Create conversation variables if they don't exist. - db_conversation_variables = [ - ConversationVariable.from_variable( - app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable - ) - for variable in self._workflow.conversation_variables - ] - session.add_all(db_conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in db_conversation_variables] - - session.commit() + # Initialize conversation variables + conversation_variables = self._initialize_conversation_variables() # Create a variable pool. system_inputs = SystemVariable( @@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): message_id=message_id, trace_manager=app_generate_entity.trace_manager, ) + + def _initialize_conversation_variables(self) -> list[VariableUnion]: + """ + Initialize conversation variables for the current conversation. + + This method: + 1. Loads existing variables from the database + 2. Creates new variables if none exist + 3. 
Syncs missing variables from the workflow definition + + :return: List of conversation variables ready for use + """ + with Session(db.engine) as session: + existing_variables = self._load_existing_conversation_variables(session) + + if not existing_variables: + # First time initialization - create all variables + existing_variables = self._create_all_conversation_variables(session) + else: + # Check and add any missing variables from the workflow + existing_variables = self._sync_missing_conversation_variables(session, existing_variables) + + # Convert to Variable objects for use in the workflow + conversation_variables = [var.to_variable() for var in existing_variables] + + session.commit() + return cast(list[VariableUnion], conversation_variables) + + def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Load existing conversation variables from the database. + + :param session: Database session + :return: List of existing conversation variables + """ + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + return list(session.scalars(stmt).all()) + + def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Create all conversation variables for a new conversation. + + :param session: Database session + :return: List of created conversation variables + """ + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in self._workflow.conversation_variables + ] + + if new_variables: + session.add_all(new_variables) + + return new_variables + + def _sync_missing_conversation_variables( + self, session: Session, existing_variables: list[ConversationVariable] + ) -> list[ConversationVariable]: + """ + Sync missing conversation variables from the workflow definition. + + This handles the case where new variables are added to a workflow + after conversations have already been created. 
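+
+        For example, if the workflow later defines a new conversation variable
+        (a hypothetical ``user_locale``, say) that did not exist when this
+        conversation was created, it is created here with the default value
+        taken from the workflow definition; variables already stored for the
+        conversation are left untouched.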
+ + :param session: Database session + :param existing_variables: List of existing conversation variables + :return: Updated list including any newly created variables + """ + # Get IDs of existing and workflow variables + existing_ids = {var.id for var in existing_variables} + workflow_variables = {var.id: var for var in self._workflow.conversation_variables} + + # Find missing variable IDs + missing_ids = set(workflow_variables.keys()) - existing_ids + + if not missing_ids: + return existing_variables + + # Create missing variables with their default values + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, + conversation_id=self.conversation.id, + variable=workflow_variables[var_id], + ) + for var_id in missing_ids + ] + + session.add_all(new_variables) + + # Return combined list + return existing_variables + new_variables diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index f0e9425e3f..f3b9dbf758 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -23,6 +23,7 @@ from core.app.entities.task_entities import ( MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + StreamEvent, WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator @@ -180,11 +181,15 @@ class MessageCycleManager: :param message_id: message id :return: """ + message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first() + event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE + return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, + event=event_type, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9aaa1f0b10..8bfbd82e1f 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel): continue status = ModelStatus.ACTIVE - if m.model in model_setting_map: + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7ce124594a..91f17568b6 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -121,9 +121,8 @@ class TokenBufferMemory: curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: - pruned_memory = [] while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: - pruned_memory.append(prompt_messages.pop(0)) + prompt_messages.pop(0) curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py rename to api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git 
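The net effect of the `_initialize_conversation_variables` refactor above is easiest to see in isolation: existing rows are loaded, a brand-new conversation gets every workflow variable, and an older conversation only gains the ones its workflow has since added. A minimal sketch of that last rule, with purely illustrative ids:

```python
# Illustrative only: the ids below are made up; the real code operates on
# ConversationVariable rows and workflow variable objects.
existing_ids = {"var-a", "var-b"}           # already stored for this conversation
workflow_ids = {"var-a", "var-b", "var-c"}  # defined on the current workflow

missing_ids = workflow_ids - existing_ids   # {"var-c"}

# Only the missing ids get new rows, created from the workflow defaults;
# rows that already exist keep whatever values the conversation has set.
```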
a/api/core/rag/datasource/vdb/clickzetta/README.md b/api/core/rag/datasource/vdb/clickzetta/README.md new file mode 100644 index 0000000000..2ee3e657d3 --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/README.md @@ -0,0 +1,190 @@ +# Clickzetta Vector Database Integration + +This module provides integration with Clickzetta Lakehouse as a vector database for Dify. + +## Features + +- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type +- **Vector Search**: Efficient similarity search using HNSW algorithm +- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities +- **Hybrid Search**: Combine vector similarity and full-text search for better results +- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing +- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments + +## Configuration + +### Required Environment Variables + +All seven configuration parameters are required: + +```bash +# Authentication +CLICKZETTA_USERNAME=your_username +CLICKZETTA_PASSWORD=your_password + +# Instance configuration +CLICKZETTA_INSTANCE=your_instance_id +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=your_workspace +CLICKZETTA_VCLUSTER=your_vcluster +CLICKZETTA_SCHEMA=your_schema +``` + +### Optional Configuration + +```bash +# Batch processing +CLICKZETTA_BATCH_SIZE=100 + +# Full-text search configuration +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode +CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart + +# Vector search configuration +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance +``` + +## Usage + +### 1. Set Clickzetta as the Vector Store + +In your Dify configuration, set: + +```bash +VECTOR_STORE=clickzetta +``` + +### 2. Table Structure + +Clickzetta will automatically create tables with the following structure: + +```sql +CREATE TABLE ( + id STRING NOT NULL, + content STRING NOT NULL, + metadata JSON, + vector VECTOR(FLOAT, ) NOT NULL, + PRIMARY KEY (id) +); + +-- Vector index for similarity search +CREATE VECTOR INDEX idx__vec +ON TABLE .(vector) +PROPERTIES ( + "distance.function" = "cosine_distance", + "scalar.type" = "f32" +); + +-- Inverted index for full-text search (if enabled) +CREATE INVERTED INDEX idx__text +ON .(content) +PROPERTIES ( + "analyzer" = "chinese", + "mode" = "smart" +); +``` + +## Full-Text Search Capabilities + +Clickzetta supports advanced full-text search with multiple analyzers: + +### Analyzer Types + +1. **keyword**: No tokenization, treats the entire string as a single token + - Best for: Exact matching, IDs, codes + +2. **english**: Designed for English text + - Features: Recognizes ASCII letters and numbers, converts to lowercase + - Best for: English content + +3. **chinese**: Chinese text tokenizer + - Features: Recognizes Chinese and English characters, removes punctuation + - Best for: Chinese or mixed Chinese-English content + +4. 
**unicode**: Multi-language tokenizer based on Unicode + - Features: Recognizes text boundaries in multiple languages + - Best for: Multi-language content + +### Analyzer Modes + +- **max_word**: Fine-grained tokenization (more tokens) +- **smart**: Intelligent tokenization (balanced) + +### Full-Text Search Functions + +- `MATCH_ALL(column, query)`: All terms must be present +- `MATCH_ANY(column, query)`: At least one term must be present +- `MATCH_PHRASE(column, query)`: Exact phrase matching +- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching +- `MATCH_REGEXP(column, pattern)`: Regular expression matching + +## Performance Optimization + +### Vector Search + +1. **Adjust exploration factor** for accuracy vs speed trade-off: + ```sql + SET cz.vector.index.search.ef=64; + ``` + +2. **Use appropriate distance functions**: + - `cosine_distance`: Best for normalized embeddings (e.g., from language models) + - `l2_distance`: Best for raw feature vectors + +### Full-Text Search + +1. **Choose the right analyzer**: + - Use `keyword` for exact matching + - Use language-specific analyzers for better tokenization + +2. **Combine with vector search**: + - Pre-filter with full-text search for better performance + - Use hybrid search for improved relevance + +## Troubleshooting + +### Connection Issues + +1. Verify all 7 required configuration parameters are set +2. Check network connectivity to Clickzetta service +3. Ensure the user has proper permissions on the schema + +### Search Performance + +1. Verify vector index exists: + ```sql + SHOW INDEX FROM .; + ``` + +2. Check if vector index is being used: + ```sql + EXPLAIN SELECT ... WHERE l2_distance(...) < threshold; + ``` + Look for `vector_index_search_type` in the execution plan. + +### Full-Text Search Not Working + +1. Verify inverted index is created +2. Check analyzer configuration matches your content language +3. Use `TOKENIZE()` function to test tokenization: + ```sql + SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart')); + ``` + +## Limitations + +1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns +2. Full-text search relevance scores are not provided by Clickzetta +3. Inverted index creation may fail for very large existing tables (continue without error) +4. Index naming constraints: + - Index names must be unique within a schema + - Only one vector index can be created per column + - The implementation uses timestamps to ensure unique index names +5. 
A column can only have one vector index at a time + +## References + +- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search) +- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index) +- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) diff --git a/api/core/rag/datasource/vdb/clickzetta/__init__.py b/api/core/rag/datasource/vdb/clickzetta/__init__.py new file mode 100644 index 0000000000..9d41c5a57d --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/__init__.py @@ -0,0 +1 @@ +# Clickzetta Vector Database Integration for Dify diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py new file mode 100644 index 0000000000..1059b855a2 --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -0,0 +1,1086 @@ +import json +import logging +import queue +import re +import threading +import time +import uuid +from typing import TYPE_CHECKING, Any, Optional + +import clickzetta # type: ignore +from pydantic import BaseModel, model_validator + +if TYPE_CHECKING: + from clickzetta import Connection + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +# ClickZetta Lakehouse Vector Database Configuration + + +class ClickzettaConfig(BaseModel): + """ + Configuration class for Clickzetta connection. + """ + + username: str + password: str + instance: str + service: str = "api.clickzetta.com" + workspace: str = "quick_start" + vcluster: str = "default_ap" + schema_name: str = "dify" # Renamed to avoid shadowing BaseModel.schema + # Advanced settings + batch_size: int = 20 # Reduced batch size to avoid large SQL statements + enable_inverted_index: bool = True # Enable inverted index for full-text search + analyzer_type: str = "chinese" # Analyzer type for full-text search: keyword, english, chinese, unicode + analyzer_mode: str = "smart" # Analyzer mode: max_word, smart + vector_distance_function: str = "cosine_distance" # l2_distance or cosine_distance + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """ + Validate the configuration values. + """ + if not values.get("username"): + raise ValueError("config CLICKZETTA_USERNAME is required") + if not values.get("password"): + raise ValueError("config CLICKZETTA_PASSWORD is required") + if not values.get("instance"): + raise ValueError("config CLICKZETTA_INSTANCE is required") + if not values.get("service"): + raise ValueError("config CLICKZETTA_SERVICE is required") + if not values.get("workspace"): + raise ValueError("config CLICKZETTA_WORKSPACE is required") + if not values.get("vcluster"): + raise ValueError("config CLICKZETTA_VCLUSTER is required") + if not values.get("schema_name"): + raise ValueError("config CLICKZETTA_SCHEMA is required") + return values + + +class ClickzettaConnectionPool: + """ + Global connection pool for ClickZetta connections. + Manages connection reuse across ClickzettaVector instances. 
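+
+    Typical usage (illustrative sketch; all names are defined in this module):
+
+        pool = ClickzettaConnectionPool.get_instance()
+        connection = pool.get_connection(config)   # config is a ClickzettaConfig
+        try:
+            ...  # run queries on the borrowed connection
+        finally:
+            pool.return_connection(config, connection)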
+ """ + + _instance: Optional["ClickzettaConnectionPool"] = None + _lock = threading.Lock() + + def __init__(self): + self._pools: dict[str, list[tuple[Connection, float]]] = {} # config_key -> [(connection, last_used_time)] + self._pool_locks: dict[str, threading.Lock] = {} + self._max_pool_size = 5 # Maximum connections per configuration + self._connection_timeout = 300 # 5 minutes timeout + self._cleanup_thread: Optional[threading.Thread] = None + self._shutdown = False + self._start_cleanup_thread() + + @classmethod + def get_instance(cls) -> "ClickzettaConnectionPool": + """Get singleton instance of connection pool.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _get_config_key(self, config: ClickzettaConfig) -> str: + """Generate unique key for connection configuration.""" + return ( + f"{config.username}:{config.instance}:{config.service}:" + f"{config.workspace}:{config.vcluster}:{config.schema_name}" + ) + + def _create_connection(self, config: ClickzettaConfig) -> "Connection": + """Create a new ClickZetta connection.""" + max_retries = 3 + retry_delay = 1.0 + + for attempt in range(max_retries): + try: + connection = clickzetta.connect( + username=config.username, + password=config.password, + instance=config.instance, + service=config.service, + workspace=config.workspace, + vcluster=config.vcluster, + schema=config.schema_name, + ) + + # Configure connection session settings + self._configure_connection(connection) + logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries) + return connection + except Exception: + logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries) + if attempt < max_retries - 1: + time.sleep(retry_delay * (2**attempt)) + else: + raise + + raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") + + def _configure_connection(self, connection: "Connection") -> None: + """Configure connection session settings.""" + try: + with connection.cursor() as cursor: + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) + + try: + # Use quote mode for string literal escaping + cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") + + # Apply performance optimization hints + performance_hints = [ + # Vector index optimization + "SET cz.storage.parquet.vector.index.read.memory.cache = true", + "SET cz.storage.parquet.vector.index.read.local.cache = false", + # Query optimization + "SET cz.sql.table.scan.push.down.filter = true", + "SET cz.sql.table.scan.enable.ensure.filter = true", + "SET cz.storage.always.prefetch.internal = true", + "SET cz.optimizer.generate.columns.always.valid = true", + "SET cz.sql.index.prewhere.enabled = true", + # Storage optimization + "SET cz.storage.parquet.enable.io.prefetch = false", + "SET cz.optimizer.enable.mv.rewrite = false", + "SET cz.sql.dump.as.lz4 = true", + "SET cz.optimizer.limited.optimization.naive.query = true", + "SET cz.sql.table.scan.enable.push.down.log = false", + "SET cz.storage.use.file.format.local.stats = false", + "SET cz.storage.local.file.object.cache.level = all", + # Job execution optimization + "SET cz.sql.job.fast.mode = true", + "SET cz.storage.parquet.non.contiguous.read = true", + "SET cz.sql.compaction.after.commit = true", + ] + + for hint in 
performance_hints: + cursor.execute(hint) + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + except Exception: + logger.exception("Failed to configure connection, continuing with defaults") + + def _is_connection_valid(self, connection: "Connection") -> bool: + """Check if connection is still valid.""" + try: + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + return True + except Exception: + return False + + def get_connection(self, config: ClickzettaConfig) -> "Connection": + """Get a connection from the pool or create a new one.""" + config_key = self._get_config_key(config) + + # Ensure pool lock exists + if config_key not in self._pool_locks: + with self._lock: + if config_key not in self._pool_locks: + self._pool_locks[config_key] = threading.Lock() + self._pools[config_key] = [] + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + current_time = time.time() + + # Try to reuse existing connection + while pool: + connection, last_used = pool.pop(0) + + # Check if connection is not expired and still valid + if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection): + logger.debug("Reusing ClickZetta connection from pool") + return connection + else: + # Connection expired or invalid, close it + try: + connection.close() + except Exception: + pass + + # No valid connection found, create new one + return self._create_connection(config) + + def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None: + """Return a connection to the pool.""" + config_key = self._get_config_key(config) + + if config_key not in self._pool_locks: + # Pool was cleaned up, just close the connection + try: + connection.close() + except Exception: + pass + return + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + + # Only return to pool if not at capacity and connection is valid + if len(pool) < self._max_pool_size and self._is_connection_valid(connection): + pool.append((connection, time.time())) + logger.debug("Returned ClickZetta connection to pool") + else: + # Pool full or connection invalid, close it + try: + connection.close() + except Exception: + pass + + def _cleanup_expired_connections(self) -> None: + """Clean up expired connections from all pools.""" + current_time = time.time() + + with self._lock: + for config_key in list(self._pools.keys()): + if config_key not in self._pool_locks: + continue + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + valid_connections = [] + + for connection, last_used in pool: + if current_time - last_used < self._connection_timeout: + valid_connections.append((connection, last_used)) + else: + try: + connection.close() + except Exception: + pass + + self._pools[config_key] = valid_connections + + def _start_cleanup_thread(self) -> None: + """Start background thread for connection cleanup.""" + + def cleanup_worker(): + while not self._shutdown: + try: + time.sleep(60) # Cleanup every minute + if not self._shutdown: + self._cleanup_expired_connections() + except Exception: + logger.exception("Error in connection pool cleanup") + + self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + self._cleanup_thread.start() + + def shutdown(self) -> None: + """Shutdown connection pool and close all connections.""" + self._shutdown = True + + with self._lock: + for config_key in list(self._pools.keys()): + if config_key not in self._pool_locks: + continue + + with 
self._pool_locks[config_key]: + pool = self._pools[config_key] + for connection, _ in pool: + try: + connection.close() + except Exception: + pass + pool.clear() + + +class ClickzettaVector(BaseVector): + """ + Clickzetta vector storage implementation. + """ + + # Class-level write queue and lock for serializing writes + _write_queue: Optional[queue.Queue] = None + _write_thread: Optional[threading.Thread] = None + _write_lock = threading.Lock() + _shutdown = False + + def __init__(self, collection_name: str, config: ClickzettaConfig): + super().__init__(collection_name) + self._config = config + self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name + self._connection_pool = ClickzettaConnectionPool.get_instance() + self._init_write_queue() + + def _get_connection(self) -> "Connection": + """Get a connection from the pool.""" + return self._connection_pool.get_connection(self._config) + + def _return_connection(self, connection: "Connection") -> None: + """Return a connection to the pool.""" + self._connection_pool.return_connection(self._config, connection) + + class ConnectionContext: + """Context manager for borrowing and returning connections.""" + + def __init__(self, vector_instance: "ClickzettaVector"): + self.vector = vector_instance + self.connection: Optional[Connection] = None + + def __enter__(self) -> "Connection": + self.connection = self.vector._get_connection() + return self.connection + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.connection: + self.vector._return_connection(self.connection) + + def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + """Get a connection context manager.""" + return self.ConnectionContext(self) + + def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict: + """ + Parse metadata from JSON string with proper error handling and fallback. 
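+
+        Illustrative example of the common case (values are made up):
+
+            _parse_metadata('{"document_id": "doc-1"}', "seg-1")
+            -> {"document_id": "doc-1", "doc_id": "seg-1"}
+
+        Double-encoded input such as '"{\"document_id\": \"doc-1\"}"' is decoded
+        twice before the same required fields are filled in.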
+ + Args: + raw_metadata: Raw JSON string from database + row_id: Row ID for fallback document_id + + Returns: + Parsed metadata dict with guaranteed required fields + """ + try: + if raw_metadata: + metadata = json.loads(raw_metadata) + + # Handle double-encoded JSON + if isinstance(metadata, str): + metadata = json.loads(metadata) + + # Ensure we have a dict + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError): + logger.exception("JSON parsing failed for metadata") + # Fallback: extract document_id with regex + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "") + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + # Ensure required fields are set + metadata["doc_id"] = row_id # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row_id # fallback to segment id + + return metadata + + @classmethod + def _init_write_queue(cls): + """Initialize the write queue and worker thread.""" + with cls._write_lock: + if cls._write_queue is None: + cls._write_queue = queue.Queue() + cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True) + cls._write_thread.start() + logger.info("Started Clickzetta write worker thread") + + @classmethod + def _write_worker(cls): + """Worker thread that processes write tasks sequentially.""" + while not cls._shutdown: + try: + # Get task from queue with timeout + if cls._write_queue is not None: + task = cls._write_queue.get(timeout=1) + if task is None: # Shutdown signal + break + + # Execute the write task + func, args, kwargs, result_queue = task + try: + result = func(*args, **kwargs) + result_queue.put((True, result)) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Write task failed") + result_queue.put((False, e)) + finally: + cls._write_queue.task_done() + else: + break + except queue.Empty: + continue + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Write worker error") + + def _execute_write(self, func, *args, **kwargs): + """Execute a write operation through the queue.""" + if ClickzettaVector._write_queue is None: + raise RuntimeError("Write queue not initialized") + + result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue() + ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) + + # Wait for result + success, result = result_queue.get() + if not success: + raise result + return result + + def get_type(self) -> str: + """Return the vector database type.""" + return "clickzetta" + + def _ensure_connection(self) -> "Connection": + """Get a connection from the pool.""" + return self._get_connection() + + def _table_exists(self) -> bool: + """Check if the table exists.""" + try: + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") + return True + except Exception as e: + error_message = str(e).lower() + # Handle ClickZetta specific "table or view not found" errors + if any( + phrase in error_message + for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"] + ): + logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name) + return False + else: + # For other connection/permission errors, log warning but return False to avoid blocking cleanup + 
logger.exception( + "Table existence check failed for %s.%s, assuming it doesn't exist", + self._config.schema_name, + self._table_name, + ) + return False + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + """Create the collection and add initial documents.""" + # Execute table creation through write queue to avoid concurrent conflicts + self._execute_write(self._create_table_and_indexes, embeddings) + + # Add initial texts + if texts: + self.add_texts(texts, embeddings, **kwargs) + + def _create_table_and_indexes(self, embeddings: list[list[float]]): + """Create table and indexes (executed in write worker thread).""" + # Check if table already exists to avoid unnecessary index creation + if self._table_exists(): + logger.info("Table %s.%s already exists, skipping creation", self._config.schema_name, self._table_name) + return + + # Create table with vector and metadata columns + dimension = len(embeddings[0]) if embeddings else 768 + + create_table_sql = f""" + CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} ( + id STRING NOT NULL COMMENT 'Unique document identifier', + {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', + {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', + {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT + 'High-dimensional embedding vector for semantic similarity search', + PRIMARY KEY (id) + ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' + """ + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(create_table_sql) + logger.info("Created table %s.%s", self._config.schema_name, self._table_name) + + # Create vector index + self._create_vector_index(cursor) + + # Create inverted index for full-text search if enabled + if self._config.enable_inverted_index: + self._create_inverted_index(cursor) + + def _create_vector_index(self, cursor): + """Create HNSW vector index for similarity search.""" + # Use a fixed index name based on table and column name + index_name = f"idx_{self._table_name}_vector" + + # First check if an index already exists on this column + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + # Check if vector index already exists on the embedding column + if Field.VECTOR.value in str(idx).lower(): + logger.info("Vector index already exists on column %s", Field.VECTOR.value) + return + except (RuntimeError, ValueError) as e: + logger.warning("Failed to check existing indexes: %s", e) + + index_sql = f""" + CREATE VECTOR INDEX IF NOT EXISTS {index_name} + ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) + PROPERTIES ( + "distance.function" = "{self._config.vector_distance_function}", + "scalar.type" = "f32", + "m" = "16", + "ef.construction" = "128" + ) + """ + try: + cursor.execute(index_sql) + logger.info("Created vector index: %s", index_name) + except (RuntimeError, ValueError) as e: + error_msg = str(e).lower() + if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg: + logger.info("Vector index already exists: %s", e) + else: + logger.exception("Failed to create vector index") + raise + + def _create_inverted_index(self, cursor): + """Create inverted index for full-text search.""" + # Use a 
fixed index name based on table name to avoid duplicates + index_name = f"idx_{self._table_name}_text" + + # Check if an inverted index already exists on this column + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + idx_str = str(idx).lower() + # More precise check: look for inverted index specifically on the content column + if ( + "inverted" in idx_str + and Field.CONTENT_KEY.value.lower() in idx_str + and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str) + ): + logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) + return + except (RuntimeError, ValueError) as e: + logger.warning("Failed to check existing indexes: %s", e) + + index_sql = f""" + CREATE INVERTED INDEX IF NOT EXISTS {index_name} + ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) + PROPERTIES ( + "analyzer" = "{self._config.analyzer_type}", + "mode" = "{self._config.analyzer_mode}" + ) + """ + try: + cursor.execute(index_sql) + logger.info("Created inverted index: %s", index_name) + except (RuntimeError, ValueError) as e: + error_msg = str(e).lower() + # Handle ClickZetta specific error messages + if ( + "already exists" in error_msg + or "already has index" in error_msg + or "with the same type" in error_msg + or "cannot create inverted index" in error_msg + ) and "already has index" in error_msg: + logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) + # Try to get the existing index name for logging + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower(): + logger.info("Found existing inverted index: %s", idx) + break + except (RuntimeError, ValueError): + pass + else: + logger.warning("Failed to create inverted index: %s", e) + # Continue without inverted index - full-text search will fall back to LIKE + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + """Add documents with embeddings to the collection.""" + if not documents: + return + + batch_size = self._config.batch_size + total_batches = (len(documents) + batch_size - 1) // batch_size + + for i in range(0, len(documents), batch_size): + batch_docs = documents[i : i + batch_size] + batch_embeddings = embeddings[i : i + batch_size] + + # Execute batch insert through write queue + self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + + def _insert_batch( + self, + batch_docs: list[Document], + batch_embeddings: list[list[float]], + batch_index: int, + batch_size: int, + total_batches: int, + ): + """Insert a batch of documents using parameterized queries (executed in write worker thread).""" + if not batch_docs or not batch_embeddings: + logger.warning("Empty batch provided, skipping insertion") + return + + if len(batch_docs) != len(batch_embeddings): + logger.error("Mismatch between docs (%d) and embeddings (%d)", len(batch_docs), len(batch_embeddings)) + return + + # Prepare data for parameterized insertion + data_rows = [] + vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 + + for doc, embedding in zip(batch_docs, batch_embeddings): + # Optimized: minimal checks for common case, fallback for edge 
cases + metadata = doc.metadata if doc.metadata else {} + + if not isinstance(metadata, dict): + metadata = {} + + doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + + # Fast path for JSON serialization + try: + metadata_json = json.dumps(metadata, ensure_ascii=True) + except (TypeError, ValueError): + logger.warning("JSON serialization failed, using empty dict") + metadata_json = "{}" + + content = doc.page_content or "" + + # According to ClickZetta docs, vector should be formatted as array string + # for external systems: '[1.0, 2.0, 3.0]' + vector_str = "[" + ",".join(map(str, embedding)) + "]" + data_rows.append([doc_id, content, metadata_json, vector_str]) + + # Check if we have any valid data to insert + if not data_rows: + logger.warning("No valid documents to insert in batch %d/%d", batch_index // batch_size + 1, total_batches) + return + + # Use parameterized INSERT with executemany for better performance and security + # Cast JSON and VECTOR in SQL, pass raw data as parameters + columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" + insert_sql = ( + f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) " + f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" + ) + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Set session-level hints for batch insert operations + # Note: executemany doesn't support hints parameter, so we set them as session variables + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) + + try: + cursor.execute("SET cz.sql.job.fast.mode = true") + cursor.execute("SET cz.sql.compaction.after.commit = true") + cursor.execute("SET cz.storage.always.prefetch.internal = true") + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + cursor.executemany(insert_sql, data_rows) + logger.info( + "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", + batch_index // batch_size + 1, + total_batches, + len(data_rows), + vector_dimension, + ) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) + logger.exception("SQL template: %s", insert_sql) + logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") + raise + + def text_exists(self, id: str) -> bool: + """Check if a document exists by ID.""" + # Check if table exists first + if not self._table_exists(): + return False + + safe_id = self._safe_doc_id(id) + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute( + f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", + binding_params=[safe_id], + ) + result = cursor.fetchone() + return result[0] > 0 if result else False + + def delete_by_ids(self, ids: list[str]) -> None: + """Delete documents by IDs.""" + if not ids: + return + + # Check if table exists before attempting delete + if not self._table_exists(): + logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name) + return + + # Execute delete through write queue + self._execute_write(self._delete_by_ids_impl, ids) + + def _delete_by_ids_impl(self, ids: list[str]) -> None: + """Implementation of 
delete by IDs (executed in write worker thread).""" + safe_ids = [self._safe_doc_id(id) for id in ids] + + # Use parameterized query to prevent SQL injection + placeholders = ",".join("?" for _ in safe_ids) + sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})" + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(sql, binding_params=safe_ids) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """Delete documents by metadata field.""" + # Check if table exists before attempting delete + if not self._table_exists(): + logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name) + return + + # Execute delete through write queue + self._execute_write(self._delete_by_metadata_field_impl, key, value) + + def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: + """Implementation of delete by metadata field (executed in write worker thread).""" + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Using JSON path to filter with parameterized query + # Note: JSON path requires literal key name, cannot be parameterized + # Use json_extract_string function for ClickZetta compatibility + sql = ( + f"DELETE FROM {self._config.schema_name}.{self._table_name} " + f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" + ) + cursor.execute(sql, binding_params=[value]) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Search for documents by vector similarity.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + top_k = kwargs.get("top_k", 10) + score_threshold = kwargs.get("score_threshold", 0.0) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Add distance threshold based on distance function + vector_dimension = len(query_vector) + if self._config.vector_distance_function == "cosine_distance": + # For cosine distance, smaller is better (0 = identical, 2 = opposite) + distance_func = "COSINE_DISTANCE" + if score_threshold > 0: + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + filter_clauses.append( + f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" + ) + else: + # For L2 distance, smaller is better + distance_func = "L2_DISTANCE" + if score_threshold > 0: + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") + + where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" + + # Execute vector search query + query_vector_str = 
f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, + {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + ORDER BY distance + LIMIT {top_k} + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Use hints parameter for vector search optimization + search_hints = { + "hints": { + "sdk.job.timeout": 60, # Increase timeout for vector search + "cz.sql.job.fast.mode": True, + "cz.storage.parquet.vector.index.read.memory.cache": True, + } + } + cursor.execute(search_sql, search_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata using centralized method + metadata = self._parse_metadata(row[2], row[0]) + + # Add score based on distance + if self._config.vector_distance_function == "cosine_distance": + metadata["score"] = 1 - (row[3] / 2) + else: + metadata["score"] = 1 / (1 + row[3]) + + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + + return documents + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Search for documents using full-text search with inverted index.""" + if not self._config.enable_inverted_index: + logger.warning("Full-text search is not enabled. Enable inverted index in config.") + return [] + + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + top_k = kwargs.get("top_k", 10) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Use match_all function for full-text search + # match_all requires all terms to be present + # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause + escaped_query = query.replace("'", "''") + filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')") + + where_clause = " AND ".join(filter_clauses) + + # Execute full-text search query + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + LIMIT {top_k} + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Use hints parameter for full-text search optimization + fulltext_hints = { + "hints": { + "sdk.job.timeout": 30, # Timeout for full-text search + "cz.sql.job.fast.mode": True, + "cz.sql.index.prewhere.enabled": True, + } + } + cursor.execute(search_sql, fulltext_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata from JSON string (may be double-encoded) + try: + if row[2]: + metadata = json.loads(row[2]) + + # If result is a string, it's double-encoded JSON - parse 
again + if isinstance(metadata, str): + metadata = json.loads(metadata) + + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError) as e: + logger.exception("JSON parsing failed") + # Fallback: extract document_id with regex + + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + # Ensure required fields are set + metadata["doc_id"] = row[0] # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row[0] # fallback to segment id + + # Add a relevance score for full-text search + metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Full-text search failed") + # Fallback to LIKE search if full-text search fails + return self._search_by_like(query, **kwargs) + + return documents + + def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]: + """Fallback search using LIKE operator.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + top_k = kwargs.get("top_k", 10) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Use simple quote escaping for LIKE clause + escaped_query = query.replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'") + where_clause = " AND ".join(filter_clauses) + + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + LIMIT {top_k} + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Use hints parameter for LIKE search optimization + like_hints = { + "hints": { + "sdk.job.timeout": 20, # Timeout for LIKE search + "cz.sql.job.fast.mode": True, + } + } + cursor.execute(search_sql, like_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata using centralized method + metadata = self._parse_metadata(row[2], row[0]) + + metadata["score"] = 0.5 # Lower score for LIKE search + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + + return documents + + def delete(self) -> None: + """Delete the entire collection.""" + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") + + def _format_vector_simple(self, vector: list[float]) -> str: + """Simple vector formatting for SQL queries.""" + return ",".join(map(str, vector)) + + 
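A quick note on the scoring used in the searches above: ClickZetta's cosine_distance lies in [0, 2], so 1 - distance / 2 maps it onto a [0, 1] relevance score, while the generic 1 / (1 + distance) branch keeps other non-negative metrics (e.g. L2) inside (0, 1]; the LIKE fallback then pins its results at a flat 0.5. A minimal standalone sketch of that mapping (the helper name is illustrative and not part of the patch):

def distance_to_score(distance: float, distance_function: str = "cosine_distance") -> float:
    """Map a non-negative vector distance onto a [0, 1] relevance score."""
    if distance_function == "cosine_distance":
        # cosine distance: 0 (identical) -> 1.0, 2 (opposite) -> 0.0
        return 1 - (distance / 2)
    # other metrics such as L2: monotonically decreasing, bounded by (0, 1]
    return 1 / (1 + distance)


assert distance_to_score(0.0) == 1.0
assert distance_to_score(2.0) == 0.0
assert 0.0 < distance_to_score(3.0, "l2_distance") < 1.0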
def _safe_doc_id(self, doc_id: str) -> str:
+        """Ensure doc_id is safe for SQL and doesn't contain special characters."""
+        if not doc_id:
+            return str(uuid.uuid4())
+        # Remove or replace potentially problematic characters
+        safe_id = str(doc_id)
+        # Only allow alphanumeric, hyphens, underscores
+        safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_")
+        if not safe_id:  # If all characters were removed
+            return str(uuid.uuid4())
+        return safe_id[:255]  # Limit length
+
+
+class ClickzettaVectorFactory(AbstractVectorFactory):
+    """Factory for creating Clickzetta vector instances."""
+
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
+        """Initialize a Clickzetta vector instance."""
+        # Get configuration from environment variables or dataset config
+        config = ClickzettaConfig(
+            username=dify_config.CLICKZETTA_USERNAME or "",
+            password=dify_config.CLICKZETTA_PASSWORD or "",
+            instance=dify_config.CLICKZETTA_INSTANCE or "",
+            service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
+            workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
+            vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
+            schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
+            batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
+            enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX if dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX is not None else True,
+            analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",
+            analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart",
+            vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance",
+        )
+
+        # Use dataset collection name as table name
+        collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
+
+        return ClickzettaVector(collection_name=collection_name, config=config)
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index 3aa4b67a78..0517d5a6d1 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -246,6 +246,10 @@ class TencentVector(BaseVector):
         return self._get_search_res(res, score_threshold)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = None
+        if document_ids_filter:
+            filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
         if not self._enable_hybrid_search:
             return []
         res = self._client.hybrid_search(
@@ -269,6 +273,7 @@
             ),
             retrieve_vector=False,
             limit=kwargs.get("top_k", 4),
+            filter=filter,
         )
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         return self._get_search_res(res, score_threshold)
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 43c49ed4b3..eef03ce412 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -172,6 +172,10 @@ class Vector:
                 from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
 
                 return MatrixoneVectorFactory
+            case VectorType.CLICKZETTA:
+                from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
+
+                return ClickzettaVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py
index 0d70947b72..a415142196 100644
---
a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -30,3 +30,4 @@ class VectorType(StrEnum): TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" + CLICKZETTA = "clickzetta" diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 14363de7d4..0eff7c186a 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, "storage") + content = self.parse_docx(self.file_path) return [ Document( page_content=content, @@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor): paragraph_content.append(run.text) return "".join(paragraph_content).strip() - def _parse_paragraph(self, paragraph, image_map): - paragraph_content = [] - for run in paragraph.runs: - if run.element.xpath(".//a:blip"): - for blip in run.element.xpath(".//a:blip"): - embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if embed_id: - rel_target = run.part.rels[embed_id].target_ref - if rel_target in image_map: - paragraph_content.append(image_map[rel_target]) - if run.text.strip(): - paragraph_content.append(run.text.strip()) - return " ".join(paragraph_content) if paragraph_content else "" - - def parse_docx(self, docx_path, image_folder): + def parse_docx(self, docx_path): doc = DocxDocument(docx_path) - os.makedirs(image_folder, exist_ok=True) content = [] diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index bcaf299892..d654463be9 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -5,14 +5,13 @@ from __future__ import annotations from typing import Any, Optional from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, Literal, RecursiveCharacterTextSplitter, Set, - TokenTextSplitter, Union, ) @@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] - if issubclass(cls, TokenTextSplitter): - extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", - "allowed_special": allowed_special, - "disallowed_special": disallowed_special, - } - kwargs = {**kwargs, **extra_kwargs} - return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 1639dd687f..a8fd6ec2cd 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool): @staticmethod def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: try: - if local_tz is None: - local_tz = datetime.now().astimezone().tzinfo - if isinstance(local_tz, str): - local_tz = pytz.timezone(local_tz) local_time = datetime.strptime(localtime, time_format) - localtime = 
local_tz.localize(local_time) # type: ignore + if local_tz is None: + localtime = local_time.astimezone() # type: ignore + elif isinstance(local_tz, str): + local_tz = pytz.timezone(local_tz) + localtime = local_tz.localize(local_time) # type: ignore timestamp = int(localtime.timestamp()) # type: ignore return timestamp except Exception as e: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 333ef2834c..e112de9578 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -1,7 +1,8 @@ import json from collections.abc import Generator +from dataclasses import dataclass from os import getenv -from typing import Any, Optional +from typing import Any, Optional, Union from urllib.parse import urlencode import httpx @@ -20,6 +21,20 @@ API_TOOL_DEFAULT_TIMEOUT = ( ) +@dataclass +class ParsedResponse: + """Represents a parsed HTTP response with type information""" + + content: Union[str, dict] + is_json: bool + + def to_string(self) -> str: + """Convert response to string format for credential validation""" + if isinstance(self.content, dict): + return json.dumps(self.content, ensure_ascii=False) + return str(self.content) + + class ApiTool(Tool): """ Api tool @@ -58,7 +73,9 @@ class ApiTool(Tool): response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response - return self.validate_and_parse_response(response) + parsed_response = self.validate_and_parse_response(response) + # For credential validation, always return as string + return parsed_response.to_string() def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.API @@ -112,23 +129,36 @@ class ApiTool(Tool): return headers - def validate_and_parse_response(self, response: httpx.Response) -> str: + def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse: """ - validate the response + validate the response and return parsed content with type information + + :return: ParsedResponse with content and is_json flag """ if isinstance(response, httpx.Response): if response.status_code >= 400: raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") if not response.content: - return "Empty response from the tool, please check your parameters and try again." 
+ return ParsedResponse( + "Empty response from the tool, please check your parameters and try again.", False + ) + + # Check content type + content_type = response.headers.get("content-type", "").lower() + is_json_content_type = "application/json" in content_type + + # Try to parse as JSON try: - response = response.json() - try: - return json.dumps(response, ensure_ascii=False) - except Exception: - return json.dumps(response) + json_response = response.json() + # If content-type indicates JSON, return as JSON object + if is_json_content_type: + return ParsedResponse(json_response, True) + else: + # If content-type doesn't indicate JSON, treat as text regardless of content + return ParsedResponse(response.text, False) except Exception: - return response.text + # Not valid JSON, return as text + return ParsedResponse(response.text, False) else: raise ValueError(f"Invalid response type {type(response)}") @@ -369,7 +399,14 @@ class ApiTool(Tool): response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters) # validate response - response = self.validate_and_parse_response(response) + parsed_response = self.validate_and_parse_response(response) - # assemble invoke message - yield self.create_text_message(response) + # assemble invoke message based on response type + if parsed_response.is_json and isinstance(parsed_response.content, dict): + yield self.create_json_message(parsed_response.content) + else: + # Convert to string if needed and create text message + text_response = ( + parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content) + ) + yield self.create_text_message(text_response) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 178f2b9689..83444c02d8 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -29,7 +29,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatorUserRole @@ -247,7 +247,8 @@ class ToolEngine: ) elif response.type == ToolInvokeMessage.MessageType.JSON: result += json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object), + ensure_ascii=False, ) else: result += str(response.message) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 9998de0465..ac12d83ef2 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,14 @@ import logging from collections.abc import Generator +from datetime import date, datetime +from decimal import Decimal from mimetypes import guess_extension -from typing import Optional +from typing import Optional, cast +from uuid import UUID + +import numpy as np +import pytz +from flask_login import current_user from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage @@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) +def safe_json_value(v): + if isinstance(v, datetime): + tz_name = 
getattr(current_user, "timezone", None) if current_user is not None else None + if not tz_name: + tz_name = "UTC" + return v.astimezone(pytz.timezone(tz_name)).isoformat() + elif isinstance(v, date): + return v.isoformat() + elif isinstance(v, UUID): + return str(v) + elif isinstance(v, Decimal): + return float(v) + elif isinstance(v, bytes): + try: + return v.decode("utf-8") + except UnicodeDecodeError: + return v.hex() + elif isinstance(v, memoryview): + return v.tobytes().hex() + elif isinstance(v, np.ndarray): + return v.tolist() + elif isinstance(v, dict): + return safe_json_dict(v) + elif isinstance(v, list | tuple | set): + return [safe_json_value(i) for i in v] + else: + return v + + +def safe_json_dict(d): + if not isinstance(d, dict): + raise TypeError("safe_json_dict() expects a dictionary (dict) as input") + return {k: safe_json_value(v) for k, v in d.items()} + + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages( @@ -113,6 +155,12 @@ class ToolFileMessageTransformer: ) else: yield message + + elif message.type == ToolInvokeMessage.MessageType.JSON: + if isinstance(message.message, ToolInvokeMessage.JsonMessage): + json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) + json_msg.json_object = safe_json_value(json_msg.json_object) + yield message else: yield message diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 13274f4e0e..a99f5eece3 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -119,6 +119,13 @@ class ObjectSegment(Segment): class ArraySegment(Segment): + @property + def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" + return super().text + @property def markdown(self) -> str: items = [] @@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment): @property def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" return json.dumps(self.value, ensure_ascii=False) diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 23512c8ce4..a61e6ba4ac 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: """Extract text from a file based on its file extension.""" match file_extension: - case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + case ( + ".txt" + | ".markdown" + | ".md" + | ".html" + | ".htm" + | ".xml" + | ".c" + | ".h" + | ".cpp" + | ".hpp" + | ".cc" + | ".cxx" + | ".c++" + | ".py" + | ".js" + | ".ts" + | ".jsx" + | ".tsx" + | ".java" + | ".php" + | ".rb" + | ".go" + | ".rs" + | ".swift" + | ".kt" + | ".scala" + | ".sh" + | ".bash" + | ".bat" + | ".ps1" + | ".sql" + | ".r" + | ".m" + | ".pl" + | ".lua" + | ".vim" + | ".asm" + | ".s" + | ".css" + | ".scss" + | ".less" + | ".sass" + | ".ini" + | ".cfg" + | ".conf" + | ".toml" + | ".env" + | ".log" + | ".vtt" + ): return _extract_text_from_plain_text(file_content) case ".json": return _extract_text_from_json(file_content) @@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) return _extract_text_from_eml(file_content) case ".msg": return _extract_text_from_msg(file_content) - case ".vtt": - return 
_extract_text_from_vtt(file_content) case ".properties": return _extract_text_from_properties(file_content) case _: diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 2106369bd6..e45f63bbec 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -91,7 +91,7 @@ class Executor: self.auth = node_data.authorization self.timeout = timeout self.ssl_verify = node_data.ssl_verify - self.params = [] + self.params = None self.headers = {} self.content = None self.files = None @@ -139,7 +139,8 @@ class Executor: (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) - self.params = result + if result: + self.params = result def _init_headers(self): """ diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index bd35278544..d13393dd14 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -69,6 +69,19 @@ class Storage: from extensions.storage.supabase_storage import SupabaseStorage return SupabaseStorage + case StorageType.CLICKZETTA_VOLUME: + from extensions.storage.clickzetta_volume.clickzetta_volume_storage import ( + ClickZettaVolumeConfig, + ClickZettaVolumeStorage, + ) + + def create_clickzetta_volume_storage(): + # ClickZettaVolumeConfig will automatically read from environment variables + # and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set + volume_config = ClickZettaVolumeConfig() + return ClickZettaVolumeStorage(volume_config) + + return create_clickzetta_volume_storage case _: raise ValueError(f"unsupported storage type {storage_type}") diff --git a/api/extensions/storage/clickzetta_volume/__init__.py b/api/extensions/storage/clickzetta_volume/__init__.py new file mode 100644 index 0000000000..8a1588034b --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/__init__.py @@ -0,0 +1,5 @@ +"""ClickZetta Volume storage implementation.""" + +from .clickzetta_volume_storage import ClickZettaVolumeStorage + +__all__ = ["ClickZettaVolumeStorage"] diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py new file mode 100644 index 0000000000..09ab37f42e --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -0,0 +1,530 @@ +"""ClickZetta Volume Storage Implementation + +This module provides storage backend using ClickZetta Volume functionality. +Supports Table Volume, User Volume, and External Volume types. 
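Typical deployments configure this backend through environment variables read by
ClickZettaVolumeConfig. The connection settings (CLICKZETTA_VOLUME_USERNAME,
CLICKZETTA_VOLUME_PASSWORD, CLICKZETTA_VOLUME_INSTANCE, ...) fall back to the matching
CLICKZETTA_* vector-database values when unset; the volume-specific settings have no such
fallback. Illustrative values:

    CLICKZETTA_VOLUME_TYPE=table            # table | user | external
    CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_
    CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km
    CLICKZETTA_VOLUME_NAME=my_volume        # only required for external volumes

With the default "table" type and prefixes, file keys are routed roughly as follows
(illustrative):

    dataset_123/doc.pdf   ->  TABLE VOLUME dataset_123, file 'doc.pdf'
    upload_files/abc.png  ->  USER VOLUME, file 'dify_km/abc.png'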
+""" + +import logging +import os +import tempfile +from collections.abc import Generator +from io import BytesIO +from pathlib import Path +from typing import Optional + +import clickzetta # type: ignore[import] +from pydantic import BaseModel, model_validator + +from extensions.storage.base_storage import BaseStorage + +from .volume_permissions import VolumePermissionManager, check_volume_permission + +logger = logging.getLogger(__name__) + + +class ClickZettaVolumeConfig(BaseModel): + """Configuration for ClickZetta Volume storage.""" + + username: str = "" + password: str = "" + instance: str = "" + service: str = "api.clickzetta.com" + workspace: str = "quick_start" + vcluster: str = "default_ap" + schema_name: str = "dify" + volume_type: str = "table" # table|user|external + volume_name: Optional[str] = None # For external volumes + table_prefix: str = "dataset_" # Prefix for table volume names + dify_prefix: str = "dify_km" # Directory prefix for User Volume + permission_check: bool = True # Enable/disable permission checking + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validate the configuration values. + + This method will first try to use CLICKZETTA_VOLUME_* environment variables, + then fall back to CLICKZETTA_* environment variables (for vector DB config). + """ + import os + + # Helper function to get environment variable with fallback + def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str: + # First try CLICKZETTA_VOLUME_* specific config + volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", "")) + if volume_value: + return str(volume_value) + + # Then try environment variables + volume_env = os.getenv(volume_key) + if volume_env: + return volume_env + + # Fall back to existing CLICKZETTA_* config + fallback_env = os.getenv(fallback_key) + if fallback_env: + return fallback_env + + return default or "" + + # Apply environment variables with fallback to existing CLICKZETTA_* config + values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME")) + values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD")) + values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE")) + values.setdefault( + "service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com") + ) + values.setdefault( + "workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start") + ) + values.setdefault( + "vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap") + ) + values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify")) + + # Volume-specific configurations (no fallback to vector DB config) + values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table")) + values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME")) + values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_")) + values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km")) + # 暂时禁用权限检查功能,直接设置为false + values.setdefault("permission_check", False) + + # Validate required fields + if not values.get("username"): + raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required") + if not values.get("password"): + 
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required") + if not values.get("instance"): + raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required") + + # Validate volume type + volume_type = values["volume_type"] + if volume_type not in ["table", "user", "external"]: + raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external") + + if volume_type == "external" and not values.get("volume_name"): + raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type") + + return values + + +class ClickZettaVolumeStorage(BaseStorage): + """ClickZetta Volume storage implementation.""" + + def __init__(self, config: ClickZettaVolumeConfig): + """Initialize ClickZetta Volume storage. + + Args: + config: ClickZetta Volume configuration + """ + self._config = config + self._connection = None + self._permission_manager: VolumePermissionManager | None = None + self._init_connection() + self._init_permission_manager() + + logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type) + + def _init_connection(self): + """Initialize ClickZetta connection.""" + try: + self._connection = clickzetta.connect( + username=self._config.username, + password=self._config.password, + instance=self._config.instance, + service=self._config.service, + workspace=self._config.workspace, + vcluster=self._config.vcluster, + schema=self._config.schema_name, + ) + logger.debug("ClickZetta connection established") + except Exception as e: + logger.exception("Failed to connect to ClickZetta") + raise + + def _init_permission_manager(self): + """Initialize permission manager.""" + try: + self._permission_manager = VolumePermissionManager( + self._connection, self._config.volume_type, self._config.volume_name + ) + logger.debug("Permission manager initialized") + except Exception as e: + logger.exception("Failed to initialize permission manager") + raise + + def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str: + """Get the appropriate volume path based on volume type.""" + if self._config.volume_type == "user": + # Add dify prefix for User Volume to organize files + return f"{self._config.dify_prefix}/{filename}" + elif self._config.volume_type == "table": + # Check if this should use User Volume (special directories) + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + # Use User Volume with dify prefix for special directories + return f"{self._config.dify_prefix}/{filename}" + + if dataset_id: + return f"{self._config.table_prefix}{dataset_id}/{filename}" + else: + # Extract dataset_id from filename if not provided + # Format: dataset_id/filename + if "/" in filename: + return filename + else: + raise ValueError("dataset_id is required for table volume or filename must include dataset_id/") + elif self._config.volume_type == "external": + return filename + else: + raise ValueError(f"Unsupported volume type: {self._config.volume_type}") + + def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str: + """Get SQL prefix for volume operations.""" + if self._config.volume_type == "user": + return "USER VOLUME" + elif self._config.volume_type == "table": + # For Dify's current file storage pattern, most files are stored in + # paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext" + # These should use USER VOLUME for better compatibility + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", 
"privkeys"]: + return "USER VOLUME" + + # Only use TABLE VOLUME for actual dataset-specific paths + # like "dataset_12345/file.pdf" or paths with dataset_ prefix + if dataset_id: + table_name = f"{self._config.table_prefix}{dataset_id}" + else: + # Default table name for generic operations + table_name = "default_dataset" + return f"TABLE VOLUME {table_name}" + elif self._config.volume_type == "external": + return f"VOLUME {self._config.volume_name}" + else: + raise ValueError(f"Unsupported volume type: {self._config.volume_type}") + + def _execute_sql(self, sql: str, fetch: bool = False): + """Execute SQL command.""" + try: + if self._connection is None: + raise RuntimeError("Connection not initialized") + with self._connection.cursor() as cursor: + cursor.execute(sql) + if fetch: + return cursor.fetchall() + return None + except Exception as e: + logger.exception("SQL execution failed: %s", sql) + raise + + def _ensure_table_volume_exists(self, dataset_id: str) -> None: + """Ensure table volume exists for the given dataset_id.""" + if self._config.volume_type != "table" or not dataset_id: + return + + # Skip for upload_files and other special directories that use USER VOLUME + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + return + + table_name = f"{self._config.table_prefix}{dataset_id}" + + try: + # Check if table exists + check_sql = f"SHOW TABLES LIKE '{table_name}'" + result = self._execute_sql(check_sql, fetch=True) + + if not result: + # Create table with volume + create_sql = f""" + CREATE TABLE {table_name} ( + id INT PRIMARY KEY AUTO_INCREMENT, + filename VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + INDEX idx_filename (filename) + ) WITH VOLUME + """ + self._execute_sql(create_sql) + logger.info("Created table volume: %s", table_name) + + except Exception as e: + logger.warning("Failed to create table volume %s: %s", table_name, e) + # Don't raise exception, let the operation continue + # The table might exist but not be visible due to permissions + + def save(self, filename: str, data: bytes) -> None: + """Save data to ClickZetta Volume. 
+ + Args: + filename: File path in volume + data: File content as bytes + """ + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + # Ensure table volume exists (for table volumes) + if dataset_id: + self._ensure_table_volume_exists(dataset_id) + + # Check permissions (if enabled) + if self._config.permission_check: + # Skip permission check for special directories that use USER VOLUME + if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + if self._permission_manager is not None: + check_volume_permission(self._permission_manager, "save", dataset_id) + + # Write data to temporary file + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file.write(data) + temp_file_path = temp_file.name + + try: + # Upload to volume + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'" + else: + sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'" + + self._execute_sql(sql) + logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path) + finally: + # Clean up temporary file + Path(temp_file_path).unlink(missing_ok=True) + + def load_once(self, filename: str) -> bytes: + """Load file content from ClickZetta Volume. 
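        Roughly the inverse of save(): the object is fetched into a temporary directory with a
        GET statement and read back into memory, e.g. (illustrative):

            GET USER VOLUME FILE 'dify_km/abc.png' TO '/tmp/<temp dir>'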
+ + Args: + filename: File path in volume + + Returns: + File content as bytes + """ + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + # Check permissions (if enabled) + if self._config.permission_check: + # Skip permission check for special directories that use USER VOLUME + if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + if self._permission_manager is not None: + check_volume_permission(self._permission_manager, "load_once", dataset_id) + + # Download to temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'" + else: + sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'" + + self._execute_sql(sql) + + # Find the downloaded file (may be in subdirectories) + downloaded_file = None + for root, dirs, files in os.walk(temp_dir): + for file in files: + if file == filename or file == os.path.basename(filename): + downloaded_file = Path(root) / file + break + if downloaded_file: + break + + if not downloaded_file or not downloaded_file.exists(): + raise FileNotFoundError(f"Downloaded file not found: {filename}") + + content = downloaded_file.read_bytes() + + logger.debug("File %s loaded from ClickZetta Volume", filename) + return content + + def load_stream(self, filename: str) -> Generator: + """Load file as stream from ClickZetta Volume. + + Args: + filename: File path in volume + + Yields: + File content chunks + """ + content = self.load_once(filename) + batch_size = 4096 + stream = BytesIO(content) + + while chunk := stream.read(batch_size): + yield chunk + + logger.debug("File %s loaded as stream from ClickZetta Volume", filename) + + def download(self, filename: str, target_filepath: str): + """Download file from ClickZetta Volume to local path. + + Args: + filename: File path in volume + target_filepath: Local target file path + """ + content = self.load_once(filename) + + with Path(target_filepath).open("wb") as f: + f.write(content) + + logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) + + def exists(self, filename: str) -> bool: + """Check if file exists in ClickZetta Volume. 
+ + Args: + filename: File path in volume + + Returns: + True if file exists, False otherwise + """ + try: + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'" + else: + sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'" + + rows = self._execute_sql(sql, fetch=True) + + exists = len(rows) > 0 + logger.debug("File %s exists check: %s", filename, exists) + return exists + except Exception as e: + logger.warning("Error checking file existence for %s: %s", filename, e) + return False + + def delete(self, filename: str): + """Delete file from ClickZetta Volume. + + Args: + filename: File path in volume + """ + if not self.exists(filename): + logger.debug("File %s not found, skip delete", filename) + return + + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"REMOVE {volume_prefix} FILE '{volume_path}'" + else: + sql = f"REMOVE {volume_prefix} FILE '{filename}'" + + self._execute_sql(sql) + + logger.debug("File %s deleted from ClickZetta Volume", filename) + + def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: + """Scan files and directories in ClickZetta Volume. 
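        Internally this issues a LIST statement; with the "user" volume type and path "upload_files"
        it becomes (illustrative):

            LIST USER VOLUME SUBDIRECTORY 'dify_km/upload_files'

        and the leading "dify_km/" is stripped from the returned paths.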
+ + Args: + path: Path to scan (dataset_id for table volumes) + files: Include files in results + directories: Include directories in results + + Returns: + List of file/directory paths + """ + try: + # For table volumes, path is treated as dataset_id + dataset_id = None + if self._config.volume_type == "table": + dataset_id = path + path = "" # Root of the table volume + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # For User Volume, add dify prefix to path + if volume_prefix == "USER VOLUME": + if path: + scan_path = f"{self._config.dify_prefix}/{path}" + sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'" + else: + sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'" + else: + if path: + sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'" + else: + sql = f"LIST {volume_prefix}" + + rows = self._execute_sql(sql, fetch=True) + + result = [] + for row in rows: + file_path = row[0] # relative_path column + + # For User Volume, remove dify prefix from results + dify_prefix_with_slash = f"{self._config.dify_prefix}/" + if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): + file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix + + if files and not file_path.endswith("/") or directories and file_path.endswith("/"): + result.append(file_path) + + logger.debug("Scanned %d items in path %s", len(result), path) + return result + + except Exception as e: + logger.exception("Error scanning path %s", path) + return [] diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py new file mode 100644 index 0000000000..d5d04f121b --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -0,0 +1,516 @@ +"""ClickZetta Volume文件生命周期管理 + +该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。 +支持知识库文件的完整生命周期管理。 +""" + +import json +import logging +from dataclasses import asdict, dataclass +from datetime import datetime, timedelta +from enum import Enum +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +class FileStatus(Enum): + """文件状态枚举""" + + ACTIVE = "active" # 活跃状态 + ARCHIVED = "archived" # 已归档 + DELETED = "deleted" # 已删除(软删除) + BACKUP = "backup" # 备份文件 + + +@dataclass +class FileMetadata: + """文件元数据""" + + filename: str + size: int | None + created_at: datetime + modified_at: datetime + version: int | None + status: FileStatus + checksum: Optional[str] = None + tags: Optional[dict[str, str]] = None + parent_version: Optional[int] = None + + def to_dict(self) -> dict: + """转换为字典格式""" + data = asdict(self) + data["created_at"] = self.created_at.isoformat() + data["modified_at"] = self.modified_at.isoformat() + data["status"] = self.status.value + return data + + @classmethod + def from_dict(cls, data: dict) -> "FileMetadata": + """从字典创建实例""" + data = data.copy() + data["created_at"] = datetime.fromisoformat(data["created_at"]) + data["modified_at"] = datetime.fromisoformat(data["modified_at"]) + data["status"] = FileStatus(data["status"]) + return cls(**data) + + +class FileLifecycleManager: + """文件生命周期管理器""" + + def __init__(self, storage, dataset_id: Optional[str] = None): + """初始化生命周期管理器 + + Args: + storage: ClickZetta Volume存储实例 + dataset_id: 数据集ID(用于Table Volume) + """ + self._storage = storage + self._dataset_id = dataset_id + self._metadata_file = ".dify_file_metadata.json" + self._version_prefix = ".versions/" + self._backup_prefix = ".backups/" + self._deleted_prefix = ".deleted/" + + # 获取权限管理器(如果存在) + 
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None) + + def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata: + """保存文件并管理生命周期 + + Args: + filename: 文件名 + data: 文件内容 + tags: 文件标签 + + Returns: + 文件元数据 + """ + # 权限检查 + if not self._check_permission(filename, "save"): + from .volume_permissions import VolumePermissionError + + raise VolumePermissionError( + f"Permission denied for lifecycle save operation on file: {filename}", + operation="save", + volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"), + dataset_id=self._dataset_id, + ) + + try: + # 1. 检查是否存在旧版本 + metadata_dict = self._load_metadata() + current_metadata = metadata_dict.get(filename) + + # 2. 如果存在旧版本,创建版本备份 + if current_metadata: + self._create_version_backup(filename, current_metadata) + + # 3. 计算文件信息 + now = datetime.now() + checksum = self._calculate_checksum(data) + new_version = (current_metadata["version"] + 1) if current_metadata else 1 + + # 4. 保存新文件 + self._storage.save(filename, data) + + # 5. 创建元数据 + created_at = now + parent_version = None + + if current_metadata: + # 如果created_at是字符串,转换为datetime + if isinstance(current_metadata["created_at"], str): + created_at = datetime.fromisoformat(current_metadata["created_at"]) + else: + created_at = current_metadata["created_at"] + parent_version = current_metadata["version"] + + file_metadata = FileMetadata( + filename=filename, + size=len(data), + created_at=created_at, + modified_at=now, + version=new_version, + status=FileStatus.ACTIVE, + checksum=checksum, + tags=tags or {}, + parent_version=parent_version, + ) + + # 6. 更新元数据 + metadata_dict[filename] = file_metadata.to_dict() + self._save_metadata(metadata_dict) + + logger.info("File %s saved with lifecycle management, version %s", filename, new_version) + return file_metadata + + except Exception as e: + logger.exception("Failed to save file with lifecycle") + raise + + def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: + """获取文件元数据 + + Args: + filename: 文件名 + + Returns: + 文件元数据,如果不存在返回None + """ + try: + metadata_dict = self._load_metadata() + if filename in metadata_dict: + return FileMetadata.from_dict(metadata_dict[filename]) + return None + except Exception as e: + logger.exception("Failed to get file metadata for %s", filename) + return None + + def list_file_versions(self, filename: str) -> list[FileMetadata]: + """列出文件的所有版本 + + Args: + filename: 文件名 + + Returns: + 文件版本列表,按版本号排序 + """ + try: + versions = [] + + # 获取当前版本 + current_metadata = self.get_file_metadata(filename) + if current_metadata: + versions.append(current_metadata) + + # 获取历史版本 + version_pattern = f"{self._version_prefix}{filename}.v*" + try: + version_files = self._storage.scan(self._dataset_id or "", files=True) + for file_path in version_files: + if file_path.startswith(f"{self._version_prefix}{filename}.v"): + # 解析版本号 + version_str = file_path.split(".v")[-1].split(".")[0] + try: + version_num = int(version_str) + # 这里简化处理,实际应该从版本文件中读取元数据 + # 暂时创建基本的元数据信息 + except ValueError: + continue + except: + # 如果无法扫描版本文件,只返回当前版本 + pass + + return sorted(versions, key=lambda x: x.version or 0, reverse=True) + + except Exception as e: + logger.exception("Failed to list file versions for %s", filename) + return [] + + def restore_version(self, filename: str, version: int) -> bool: + """恢复文件到指定版本 + + Args: + filename: 文件名 + version: 要恢复的版本号 + + Returns: + 恢复是否成功 + """ + try: + version_filename = 
f"{self._version_prefix}{filename}.v{version}" + + # 检查版本文件是否存在 + if not self._storage.exists(version_filename): + logger.warning("Version %s of %s not found", version, filename) + return False + + # 读取版本文件内容 + version_data = self._storage.load_once(version_filename) + + # 保存当前版本为备份 + current_metadata = self.get_file_metadata(filename) + if current_metadata: + self._create_version_backup(filename, current_metadata.to_dict()) + + # 恢复文件 + self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)}) + return True + + except Exception as e: + logger.exception("Failed to restore %s to version %s", filename, version) + return False + + def archive_file(self, filename: str) -> bool: + """归档文件 + + Args: + filename: 文件名 + + Returns: + 归档是否成功 + """ + # 权限检查 + if not self._check_permission(filename, "archive"): + logger.warning("Permission denied for archive operation on file: %s", filename) + return False + + try: + # 更新文件状态为归档 + metadata_dict = self._load_metadata() + if filename not in metadata_dict: + logger.warning("File %s not found in metadata", filename) + return False + + metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value + metadata_dict[filename]["modified_at"] = datetime.now().isoformat() + + self._save_metadata(metadata_dict) + + logger.info("File %s archived successfully", filename) + return True + + except Exception as e: + logger.exception("Failed to archive file %s", filename) + return False + + def soft_delete_file(self, filename: str) -> bool: + """软删除文件(移动到删除目录) + + Args: + filename: 文件名 + + Returns: + 删除是否成功 + """ + # 权限检查 + if not self._check_permission(filename, "delete"): + logger.warning("Permission denied for soft delete operation on file: %s", filename) + return False + + try: + # 检查文件是否存在 + if not self._storage.exists(filename): + logger.warning("File %s not found", filename) + return False + + # 读取文件内容 + file_data = self._storage.load_once(filename) + + # 移动到删除目录 + deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}" + self._storage.save(deleted_filename, file_data) + + # 删除原文件 + self._storage.delete(filename) + + # 更新元数据 + metadata_dict = self._load_metadata() + if filename in metadata_dict: + metadata_dict[filename]["status"] = FileStatus.DELETED.value + metadata_dict[filename]["modified_at"] = datetime.now().isoformat() + self._save_metadata(metadata_dict) + + logger.info("File %s soft deleted successfully", filename) + return True + + except Exception as e: + logger.exception("Failed to soft delete file %s", filename) + return False + + def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int: + """清理旧版本文件 + + Args: + max_versions: 保留的最大版本数 + max_age_days: 版本文件的最大保留天数 + + Returns: + 清理的文件数量 + """ + try: + cleaned_count = 0 + cutoff_date = datetime.now() - timedelta(days=max_age_days) + + # 获取所有版本文件 + try: + all_files = self._storage.scan(self._dataset_id or "", files=True) + version_files = [f for f in all_files if f.startswith(self._version_prefix)] + + # 按文件分组 + file_versions: dict[str, list[tuple[int, str]]] = {} + for version_file in version_files: + # 解析文件名和版本 + parts = version_file[len(self._version_prefix) :].split(".v") + if len(parts) >= 2: + base_filename = parts[0] + version_part = parts[1].split(".")[0] + try: + version_num = int(version_part) + if base_filename not in file_versions: + file_versions[base_filename] = [] + file_versions[base_filename].append((version_num, version_file)) + except ValueError: + continue + + # 清理每个文件的旧版本 + for base_filename, 
versions in file_versions.items(): + # 按版本号排序 + versions.sort(key=lambda x: x[0], reverse=True) + + # 保留最新的max_versions个版本,删除其余的 + if len(versions) > max_versions: + to_delete = versions[max_versions:] + for version_num, version_file in to_delete: + self._storage.delete(version_file) + cleaned_count += 1 + logger.debug("Cleaned old version: %s", version_file) + + logger.info("Cleaned %d old version files", cleaned_count) + + except Exception as e: + logger.warning("Could not scan for version files: %s", e) + + return cleaned_count + + except Exception as e: + logger.exception("Failed to cleanup old versions") + return 0 + + def get_storage_statistics(self) -> dict[str, Any]: + """获取存储统计信息 + + Returns: + 存储统计字典 + """ + try: + metadata_dict = self._load_metadata() + + stats: dict[str, Any] = { + "total_files": len(metadata_dict), + "active_files": 0, + "archived_files": 0, + "deleted_files": 0, + "total_size": 0, + "versions_count": 0, + "oldest_file": None, + "newest_file": None, + } + + oldest_date = None + newest_date = None + + for filename, metadata in metadata_dict.items(): + file_meta = FileMetadata.from_dict(metadata) + + # 统计文件状态 + if file_meta.status == FileStatus.ACTIVE: + stats["active_files"] = (stats["active_files"] or 0) + 1 + elif file_meta.status == FileStatus.ARCHIVED: + stats["archived_files"] = (stats["archived_files"] or 0) + 1 + elif file_meta.status == FileStatus.DELETED: + stats["deleted_files"] = (stats["deleted_files"] or 0) + 1 + + # 统计大小 + stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0) + + # 统计版本 + stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0) + + # 找出最新和最旧的文件 + if oldest_date is None or file_meta.created_at < oldest_date: + oldest_date = file_meta.created_at + stats["oldest_file"] = filename + + if newest_date is None or file_meta.modified_at > newest_date: + newest_date = file_meta.modified_at + stats["newest_file"] = filename + + return stats + + except Exception as e: + logger.exception("Failed to get storage statistics") + return {} + + def _create_version_backup(self, filename: str, metadata: dict): + """创建版本备份""" + try: + # 读取当前文件内容 + current_data = self._storage.load_once(filename) + + # 保存为版本文件 + version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}" + self._storage.save(version_filename, current_data) + + logger.debug("Created version backup: %s", version_filename) + + except Exception as e: + logger.warning("Failed to create version backup for %s: %s", filename, e) + + def _load_metadata(self) -> dict[str, Any]: + """加载元数据文件""" + try: + if self._storage.exists(self._metadata_file): + metadata_content = self._storage.load_once(self._metadata_file) + result = json.loads(metadata_content.decode("utf-8")) + return dict(result) if result else {} + else: + return {} + except Exception as e: + logger.warning("Failed to load metadata: %s", e) + return {} + + def _save_metadata(self, metadata_dict: dict): + """保存元数据文件""" + try: + metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) + self._storage.save(self._metadata_file, metadata_content.encode("utf-8")) + logger.debug("Metadata saved successfully") + except Exception as e: + logger.exception("Failed to save metadata") + raise + + def _calculate_checksum(self, data: bytes) -> str: + """计算文件校验和""" + import hashlib + + return hashlib.md5(data).hexdigest() + + def _check_permission(self, filename: str, operation: str) -> bool: + """检查文件操作权限 + + Args: + filename: 文件名 + operation: 操作类型 + + Returns: + True if 
permission granted, False otherwise + """ + # 如果没有权限管理器,默认允许 + if not self._permission_manager: + return True + + try: + # 根据操作类型映射到权限 + operation_mapping = { + "save": "save", + "load": "load_once", + "delete": "delete", + "archive": "delete", # 归档需要删除权限 + "restore": "save", # 恢复需要写权限 + "cleanup": "delete", # 清理需要删除权限 + "read": "load_once", + "write": "save", + } + + mapped_operation = operation_mapping.get(operation, operation) + + # 检查权限 + result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id) + return bool(result) + + except Exception as e: + logger.exception("Permission check failed for %s operation %s", filename, operation) + # 安全默认:权限检查失败时拒绝访问 + return False diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py new file mode 100644 index 0000000000..4801df5102 --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -0,0 +1,646 @@ +"""ClickZetta Volume权限管理机制 + +该模块提供Volume权限检查、验证和管理功能。 +根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。 +""" + +import logging +from enum import Enum +from typing import Optional + +logger = logging.getLogger(__name__) + + +class VolumePermission(Enum): + """Volume权限类型枚举""" + + READ = "SELECT" # 对应ClickZetta的SELECT权限 + WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限 + LIST = "SELECT" # 列出文件需要SELECT权限 + DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限 + USAGE = "USAGE" # External Volume需要的基本权限 + + +class VolumePermissionManager: + """Volume权限管理器""" + + def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): + """初始化权限管理器 + + Args: + connection_or_config: ClickZetta连接对象或配置字典 + volume_type: Volume类型 (user|table|external) + volume_name: Volume名称 (用于external volume) + """ + # 支持两种初始化方式:连接对象或配置字典 + if isinstance(connection_or_config, dict): + # 从配置字典创建连接 + import clickzetta # type: ignore[import-untyped] + + config = connection_or_config + self._connection = clickzetta.connect( + username=config.get("username"), + password=config.get("password"), + instance=config.get("instance"), + service=config.get("service"), + workspace=config.get("workspace"), + vcluster=config.get("vcluster"), + schema=config.get("schema") or config.get("database"), + ) + self._volume_type = config.get("volume_type", volume_type) + self._volume_name = config.get("volume_name", volume_name) + else: + # 直接使用连接对象 + self._connection = connection_or_config + self._volume_type = volume_type + self._volume_name = volume_name + + if not self._connection: + raise ValueError("Valid connection or config is required") + if not self._volume_type: + raise ValueError("volume_type is required") + + self._permission_cache: dict[str, set[str]] = {} + self._current_username = None # 将从连接中获取当前用户名 + + def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: + """检查用户是否有执行特定操作的权限 + + Args: + operation: 要执行的操作类型 + dataset_id: 数据集ID (用于table volume) + + Returns: + True if user has permission, False otherwise + """ + try: + if self._volume_type == "user": + return self._check_user_volume_permission(operation) + elif self._volume_type == "table": + return self._check_table_volume_permission(operation, dataset_id) + elif self._volume_type == "external": + return self._check_external_volume_permission(operation) + else: + logger.warning("Unknown volume type: %s", self._volume_type) + return False + + except Exception as e: + logger.exception("Permission check failed") + return False + + 
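For reference, a minimal usage sketch of the permission manager defined above (the connection parameters are illustrative placeholders, not part of the patch; assumes the clickzetta client is installed and reachable):

import clickzetta

connection = clickzetta.connect(
    username="dify",
    password="***",
    instance="my_instance",
    service="api.clickzetta.com",
    workspace="quick_start",
    vcluster="default_ap",
    schema="dify",
)

manager = VolumePermissionManager(connection, volume_type="table")
# Table Volume permissions are inherited from the backing table (dataset_<id>),
# so the dataset id has to be supplied for the check.
if manager.check_permission(VolumePermission.READ, dataset_id="12345"):
    print("SELECT on dataset_12345 granted; files can be read")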
def _check_user_volume_permission(self, operation: VolumePermission) -> bool: + """检查User Volume权限 + + User Volume权限规则: + - 用户对自己的User Volume有全部权限 + - 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限 + - 更注重连接身份验证,而不是复杂的权限检查 + """ + try: + # 获取当前用户名 + current_user = self._get_current_username() + + # 检查基本连接状态 + with self._connection.cursor() as cursor: + # 简单的连接测试,如果能执行查询说明用户有基本权限 + cursor.execute("SELECT 1") + result = cursor.fetchone() + + if result: + logger.debug( + "User Volume permission check for %s, operation %s: granted (basic connection verified)", + current_user, + operation.name, + ) + return True + else: + logger.warning( + "User Volume permission check failed: cannot verify basic connection for %s", current_user + ) + return False + + except Exception as e: + logger.exception("User Volume permission check failed") + # 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示 + logger.info("User Volume permission check failed, but permission checking is disabled in this version") + return False + + def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: + """检查Table Volume权限 + + Table Volume权限规则: + - Table Volume权限继承对应表的权限 + - SELECT权限 -> 可以READ/LIST文件 + - INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件 + """ + if not dataset_id: + logger.warning("dataset_id is required for table volume permission check") + return False + + table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id + + try: + # 检查表权限 + permissions = self._get_table_permissions(table_name) + required_permissions = set(operation.value.split(",")) + + # 检查是否有所需的所有权限 + has_permission = required_permissions.issubset(permissions) + + logger.debug( + "Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s", + table_name, + operation.name, + required_permissions, + permissions, + has_permission, + ) + + return has_permission + + except Exception as e: + logger.exception("Table volume permission check failed for %s", table_name) + return False + + def _check_external_volume_permission(self, operation: VolumePermission) -> bool: + """检查External Volume权限 + + External Volume权限规则: + - 尝试获取对External Volume的权限 + - 如果权限检查失败,进行备选验证 + - 对于开发环境,提供更宽松的权限检查 + """ + if not self._volume_name: + logger.warning("volume_name is required for external volume permission check") + return False + + try: + # 检查External Volume权限 + permissions = self._get_external_volume_permissions(self._volume_name) + + # External Volume权限映射:根据操作类型确定所需权限 + required_permissions = set() + + if operation in [VolumePermission.READ, VolumePermission.LIST]: + required_permissions.add("read") + elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: + required_permissions.add("write") + + # 检查是否有所需的所有权限 + has_permission = required_permissions.issubset(permissions) + + logger.debug( + "External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s", + self._volume_name, + operation.name, + required_permissions, + permissions, + has_permission, + ) + + # 如果权限检查失败,尝试备选验证 + if not has_permission: + logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name) + + # 备选验证:尝试列出Volume来验证基本访问权限 + try: + with self._connection.cursor() as cursor: + cursor.execute("SHOW VOLUMES") + volumes = cursor.fetchall() + for volume in volumes: + if len(volume) > 0 and volume[0] == self._volume_name: + logger.info("Fallback verification successful for %s", self._volume_name) + return True + except Exception as fallback_e: + 
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e) + + return has_permission + + except Exception as e: + logger.exception("External volume permission check failed for %s", self._volume_name) + logger.info("External Volume permission check failed, but permission checking is disabled in this version") + return False + + def _get_table_permissions(self, table_name: str) -> set[str]: + """获取用户对指定表的权限 + + Args: + table_name: 表名 + + Returns: + 用户对该表的权限集合 + """ + cache_key = f"table:{table_name}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查当前用户权限 + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() + + # 解析权限结果,查找对该表的权限 + for grant in grants: + if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) + privilege = grant[0].upper() + object_type = grant[1].upper() if len(grant) > 1 else "" + object_name = grant[2] if len(grant) > 2 else "" + + # 检查是否是对该表的权限 + if ( + object_type == "TABLE" + and object_name == table_name + or object_type == "SCHEMA" + and object_name in table_name + ): + if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: + if privilege == "ALL": + permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) + else: + permissions.add(privilege) + + # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限 + if not permissions: + try: + cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1") + permissions.add("SELECT") + except Exception: + logger.debug("Cannot query table %s, no SELECT permission", table_name) + + except Exception as e: + logger.warning("Could not check table permissions for %s: %s", table_name, e) + # 安全默认:权限检查失败时拒绝访问 + pass + + # 缓存权限信息 + self._permission_cache[cache_key] = permissions + return permissions + + def _get_current_username(self) -> str: + """获取当前用户名""" + if self._current_username: + return self._current_username + + try: + with self._connection.cursor() as cursor: + cursor.execute("SELECT CURRENT_USER()") + result = cursor.fetchone() + if result: + self._current_username = result[0] + return str(self._current_username) + except Exception as e: + logger.exception("Failed to get current username") + + return "unknown" + + def _get_user_permissions(self, username: str) -> set[str]: + """获取用户的基本权限集合""" + cache_key = f"user_permissions:{username}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查当前用户权限 + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() + + # 解析权限结果,查找用户的基本权限 + for grant in grants: + if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) 
+                        privilege = grant[0].upper()
+                        object_type = grant[1].upper() if len(grant) > 1 else ""
+
+                        # Collect all relevant privileges
+                        if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
+                            if privilege == "ALL":
+                                permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
+                            else:
+                                permissions.add(privilege)
+
+        except Exception as e:
+            logger.warning("Could not check user permissions for %s: %s", username, e)
+            # Fail closed: leave the permission set empty when the check fails
+
+        # Cache the result
+        self._permission_cache[cache_key] = permissions
+        return permissions
+
+    def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
+        """Get the current user's privileges on the given External Volume.
+
+        Args:
+            volume_name: External Volume name
+
+        Returns:
+            The set of privileges the user holds on the volume
+        """
+        cache_key = f"external_volume:{volume_name}"
+
+        if cache_key in self._permission_cache:
+            return self._permission_cache[cache_key]
+
+        permissions = set()
+
+        try:
+            with self._connection.cursor() as cursor:
+                # Use ClickZetta syntax to inspect the grants on the volume
+                logger.info("Checking permissions for volume: %s", volume_name)
+                cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
+                grants = cursor.fetchall()
+
+                logger.info("Raw grants result for %s: %s", volume_name, grants)
+
+                # Parse the grants.
+                # Row format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
+                #              grantee_name, grantor_name, grant_option, granted_time)
+                for grant in grants:
+                    logger.info("Processing grant: %s", grant)
+                    if len(grant) >= 5:
+                        granted_type = grant[0]
+                        privilege = grant[1].upper()
+                        granted_on = grant[3]
+                        object_name = grant[4]
+
+                        logger.info(
+                            "Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
+                            granted_type,
+                            privilege,
+                            granted_on,
+                            object_name,
+                        )
+
+                        # Match grants on this volume directly, or hierarchy-level grants
+                        if (
+                            granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
+                        ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
+                            logger.info("Matching grant found for %s", volume_name)
+
+                            if "READ" in privilege:
+                                permissions.add("read")
+                                logger.info("Added READ permission for %s", volume_name)
+                            if "WRITE" in privilege:
+                                permissions.add("write")
+                                logger.info("Added WRITE permission for %s", volume_name)
+                            if "ALTER" in privilege:
+                                permissions.add("alter")
+                                logger.info("Added ALTER permission for %s", volume_name)
+                            if privilege == "ALL":
+                                permissions.update(["read", "write", "alter"])
+                                logger.info("Added ALL permissions for %s", volume_name)
+
+                logger.info("Final permissions for %s: %s", volume_name, permissions)
+
+                # If no explicit grant was found, check SHOW VOLUMES to verify basic access
+                if not permissions:
+                    try:
+                        cursor.execute("SHOW VOLUMES")
+                        volumes = cursor.fetchall()
+                        for volume in volumes:
+                            if len(volume) > 0 and volume[0] == volume_name:
+                                permissions.add("read")  # at least read access
+                                logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
+                                break
+                    except Exception:
+                        logger.debug("Cannot access volume %s, no basic permission", volume_name)
+
+        except Exception as e:
+            logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
+            # When the grant lookup fails, fall back to a basic volume-access check
+            try:
+                with self._connection.cursor() as cursor:
+                    cursor.execute("SHOW VOLUMES")
+                    volumes = cursor.fetchall()
+                    for volume in volumes:
+                        if len(volume) > 0 and volume[0] == volume_name:
+                            logger.info("Basic volume access verified for %s", volume_name)
+                            permissions.add("read")
+                            permissions.add("write")  # assume write access as well
+                            break
+            except Exception as basic_e:
+                logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
+                # Last resort: assume basic read access
+                permissions.add("read")
+
+        # Cache the result
+        self._permission_cache[cache_key] = permissions
+        return permissions
+
+    def clear_permission_cache(self):
+        """Clear the permission cache."""
+        self._permission_cache.clear()
+        logger.debug("Permission cache cleared")
+
+    def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
+        """Build a permission summary.
+
+        Args:
+            dataset_id: dataset ID (used for table volumes)
+
+        Returns:
+            A dict mapping operation names to whether they are permitted
+        """
+        summary = {}
+
+        for operation in VolumePermission:
+            summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
+
+        return summary
+
+    def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
+        """Check permission inheritance along a file path.
+
+        Args:
+            file_path: file path
+            operation: the operation to perform
+
+        Returns:
+            True if user has permission, False otherwise
+        """
+        try:
+            # Parse the file path
+            path_parts = file_path.strip("/").split("/")
+
+            if not path_parts:
+                logger.warning("Invalid file path for permission inheritance check")
+                return False
+
+            # For Table Volumes the first path segment is the dataset_id
+            if self._volume_type == "table":
+                if len(path_parts) < 1:
+                    return False
+
+                dataset_id = path_parts[0]
+
+                # Check the permission on the dataset
+                has_dataset_permission = self.check_permission(operation, dataset_id)
+
+                if not has_dataset_permission:
+                    logger.debug("Permission denied for dataset %s", dataset_id)
+                    return False
+
+                # Guard against path traversal
+                if self._contains_path_traversal(file_path):
+                    logger.warning("Path traversal attack detected: %s", file_path)
+                    return False
+
+                # Deny access to sensitive paths
+                if self._is_sensitive_path(file_path):
+                    logger.warning("Access to sensitive path denied: %s", file_path)
+                    return False
+
+                logger.debug("Permission inherited for path %s", file_path)
+                return True
+
+            elif self._volume_type == "user":
+                # Permission inheritance for User Volumes
+                current_user = self._get_current_username()
+
+                # Refuse attempts to access another user's directory
+                if len(path_parts) > 1 and path_parts[0] != current_user:
+                    logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
+                    return False
+
+                # Check the basic permission
+                return self.check_permission(operation)
+
+            elif self._volume_type == "external":
+                # Permission inheritance for External Volumes:
+                # check the permission on the External Volume itself
+                return self.check_permission(operation)
+
+            else:
+                logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
+                return False
+
+        except Exception:
+            logger.exception("Permission inheritance check failed")
+            return False
+
+    def _contains_path_traversal(self, file_path: str) -> bool:
+        """Return True if the path contains path-traversal patterns."""
+        # Common path-traversal patterns
+        traversal_patterns = [
+            "../",
+            "..\\",
+            "..%2f",
+            "..%2F",
+            "..%5c",
+            "..%5C",
+            "%2e%2e%2f",
+            "%2e%2e%5c",
+            "....//",
+            "....\\\\",
+        ]
+
+        file_path_lower = file_path.lower()
+
+        for pattern in traversal_patterns:
+            if pattern in file_path_lower:
+                return True
+
+        # Reject absolute paths
+        if file_path.startswith("/") or file_path.startswith("\\"):
+            return True
+
+        # Reject Windows drive paths
+        if len(file_path) >= 2 and file_path[1] == ":":
+            return True
+
+        return False
+
+    def _is_sensitive_path(self, file_path: str) -> bool:
+        """Return True if the path looks like a sensitive location."""
+        sensitive_patterns = [
+            "passwd",
+            "shadow",
+            "hosts",
+            "config",
+            "secrets",
+            "private",
+            "key",
+            "certificate",
+            "cert",
+            "ssl",
+            "database",
+            "backup",
+            "dump",
+            "log",
+            "tmp",
+        ]
+
+        file_path_lower = file_path.lower()
+
+        return any(pattern in file_path_lower for pattern in sensitive_patterns)
+
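+    # Example (illustrative sketch): guarding a table-volume read via inherited path
+    # permissions, assuming dataset files are stored under "<dataset_id>/<filename>":
+    #
+    #     if manager.check_inherited_permission("12345/segments.json", VolumePermission.READ):
+    #         ...  # path is safe and the dataset's table grants SELECT
+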
+    def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
+        """Validate an operation name against the volume permissions.
+
+        Args:
+            operation: operation name (save|load|exists|delete|scan)
+            dataset_id: dataset ID
+
+        Returns:
+            True if operation is allowed, False otherwise
+        """
+        operation_mapping = {
+            "save": VolumePermission.WRITE,
+            "load": VolumePermission.READ,
+            "load_once": VolumePermission.READ,
+            "load_stream": VolumePermission.READ,
+            "download": VolumePermission.READ,
+            "exists": VolumePermission.READ,
+            "delete": VolumePermission.DELETE,
+            "scan": VolumePermission.LIST,
+        }
+
+        if operation not in operation_mapping:
+            logger.warning("Unknown operation: %s", operation)
+            return False
+
+        volume_permission = operation_mapping[operation]
+        return self.check_permission(volume_permission, dataset_id)
+
+
+class VolumePermissionError(Exception):
+    """Raised when a Volume operation is not permitted."""
+
+    def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
+        self.operation = operation
+        self.volume_type = volume_type
+        self.dataset_id = dataset_id
+        super().__init__(message)
+
+
+def check_volume_permission(
+    permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
+) -> None:
+    """Permission-check helper that raises instead of returning a bool.
+
+    Args:
+        permission_manager: the permission manager
+        operation: operation name
+        dataset_id: dataset ID
+
+    Raises:
+        VolumePermissionError: if the operation is not permitted
+    """
+    if not permission_manager.validate_operation(operation, dataset_id):
+        error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
+        if dataset_id:
+            error_message += f" (dataset: {dataset_id})"
+
+        raise VolumePermissionError(
+            error_message,
+            operation=operation,
+            volume_type=permission_manager._volume_type or "unknown",
+            dataset_id=dataset_id,
+        )
diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index 0a891e36cf..bc2d632159 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -5,6 +5,7 @@ class StorageType(StrEnum): ALIYUN_OSS = "aliyun-oss" AZURE_BLOB = "azure-blob" BAIDU_OBS = "baidu-obs" + CLICKZETTA_VOLUME = "clickzetta-volume" GOOGLE_STORAGE = "google-storage" HUAWEI_OBS = "huawei-obs" LOCAL = "local" diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 598e5bc9e3..c72032701f 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,5 +1,4 @@ import hashlib -import os from typing import Union from Crypto.Cipher import AES @@ -18,7 +17,7 @@ def generate_key_pair(tenant_id: str) -> str: pem_private = private_key.export_key() pem_public = public_key.export_key() - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" storage.save(filepath, pem_private) @@ -48,7 +47,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}" private_key = redis_client.get(cache_key) diff --git a/api/pyproject.toml b/api/pyproject.toml index 9d979eca1c..a86ec7ee6b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -194,6 +194,7 @@ vdb = [ "alibabacloud_tea_openapi~=0.3.9", "chromadb==0.5.20", "clickhouse-connect~=0.7.16", + "clickzetta-connector-python>=0.8.102", "couchbase~=4.3.0", "elasticsearch==8.14.0", "opensearch-py==2.4.0", @@ -213,3 +214,4 @@ vdb = [ "xinference-client~=1.2.2", "mo-vector~=0.1.13", ] + diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 2298acf6eb..2b74fb2dd0 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ 
b/api/schedule/clean_embedding_cache_task.py @@ -3,7 +3,7 @@ import time import click from sqlalchemy import text -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -27,8 +27,8 @@ def clean_embedding_cache_task(): .all() ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] - except NotFound: - break + except SQLAlchemyError: + raise if embedding_ids: for embedding_id in embedding_ids: db.session.execute( diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 4c35745959..a896c818a5 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -3,7 +3,7 @@ import logging import time import click -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -42,8 +42,8 @@ def clean_messages(): .all() ) - except NotFound: - break + except SQLAlchemyError: + raise if not messages: break for message in messages: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 7887835bc5..940da5309e 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -3,7 +3,7 @@ import time import click from sqlalchemy import func, select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -65,8 +65,8 @@ def clean_unused_datasets_task(): datasets = db.paginate(stmt, page=1, per_page=50) - except NotFound: - break + except SQLAlchemyError: + raise if datasets.items is None or len(datasets.items) == 0: break for dataset in datasets: @@ -146,8 +146,8 @@ def clean_unused_datasets_task(): ) datasets = db.paginate(stmt, page=1, per_page=50) - except NotFound: - break + except SQLAlchemyError: + raise if datasets.items is None or len(datasets.items) == 0: break for dataset in datasets: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 692a3639cd..713c4c6782 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -50,12 +50,16 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - # Check if include_ids is not None and not empty to avoid WHERE false condition - if include_ids is not None and len(include_ids) > 0: + # Check if include_ids is not None to apply filter + if include_ids is not None: + if len(include_ids) == 0: + # If include_ids is empty, return empty result + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) stmt = stmt.where(Conversation.id.in_(include_ids)) - # Check if exclude_ids is not None and not empty to avoid WHERE false condition - if exclude_ids is not None and len(exclude_ids) > 0: - stmt = stmt.where(~Conversation.id.in_(exclude_ids)) + # Check if exclude_ids is not None to apply filter + if exclude_ids is not None: + if len(exclude_ids) > 0: + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 2d62d49d91..6bbb3bca04 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -256,7 +256,7 
@@ class WorkflowDraftVariableService: def _reset_node_var_or_sys_var( self, workflow: Workflow, variable: WorkflowDraftVariable ) -> WorkflowDraftVariable | None: - # If a variable does not allow updating, it makes no sence to resetting it. + # If a variable does not allow updating, it makes no sense to reset it. if not variable.editable: return variable # No execution record for this variable, delete the variable instead. @@ -478,7 +478,7 @@ def _batch_upsert_draft_variable( "node_execution_id": stmt.excluded.node_execution_id, }, ) - elif _UpsertPolicy.IGNORE: + elif policy == _UpsertPolicy.IGNORE: stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) else: raise Exception("Invalid value for update policy.") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index fe6d613b1c..69e5df0253 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -56,15 +56,24 @@ def clean_dataset_task( documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() + # Fix: Always clean vector database resources regardless of document existence + # This ensures all 33 vector databases properly drop tables/collections/indices + if doc_form is None: + # Use default paragraph index type for empty datasets to enable vector database cleanup + from core.rag.index_processor.constant.index_type import IndexType + + doc_form = IndexType.PARAGRAPH_INDEX + logging.info( + click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow") + ) + + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) + if documents is None or len(documents) == 0: logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green")) else: logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) - # Specify the index type before initializing the index processor - if doc_form is None: - raise ValueError("Index type must be specified.") - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..2d0ceac760 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,168 @@ +""" +Unit tests for App description validation functions. + +This test module validates the 400-character limit enforcement +for App descriptions across all creation and editing endpoints. 
+""" + +import os +import sys + +import pytest + +# Add the API root to Python path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +class TestAppDescriptionValidationUnit: + """Unit tests for description validation function""" + + def test_validate_description_length_function(self): + """Test the _validate_description_length function directly""" + from controllers.console.app.app import _validate_description_length + + # Test valid descriptions + assert _validate_description_length("") == "" + assert _validate_description_length("x" * 400) == "x" * 400 + assert _validate_description_length(None) is None + + # Test invalid descriptions + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 401) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 500) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 1000) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_validation_consistency_with_dataset(self): + """Test that App and Dataset validation functions are consistent""" + from controllers.console.app.app import _validate_description_length as app_validate + from controllers.console.datasets.datasets import _validate_description_length as dataset_validate + from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + # Test same valid inputs + valid_desc = "x" * 400 + assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc) + assert app_validate("") == dataset_validate("") == service_dataset_validate("") + assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None) + + # Test same invalid inputs produce same error + invalid_desc = "x" * 401 + + app_error = None + dataset_error = None + service_dataset_error = None + + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + assert app_error == dataset_error == service_dataset_error + assert app_error == "Description cannot exceed 400 characters." 
+ + def test_boundary_values(self): + """Test boundary values for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test exact boundary + exactly_400 = "x" * 400 + assert _validate_description_length(exactly_400) == exactly_400 + + # Test just over boundary + just_over_400 = "x" * 401 + with pytest.raises(ValueError): + _validate_description_length(just_over_400) + + # Test just under boundary + just_under_400 = "x" * 399 + assert _validate_description_length(just_under_400) == just_under_400 + + def test_edge_cases(self): + """Test edge cases for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test None input + assert _validate_description_length(None) is None + + # Test empty string + assert _validate_description_length("") == "" + + # Test single character + assert _validate_description_length("a") == "a" + + # Test unicode characters + unicode_desc = "测试" * 200 # 400 characters in Chinese + assert _validate_description_length(unicode_desc) == unicode_desc + + # Test unicode over limit + unicode_over = "测试" * 201 # 402 characters + with pytest.raises(ValueError): + _validate_description_length(unicode_over) + + def test_whitespace_handling(self): + """Test how validation handles whitespace""" + from controllers.console.app.app import _validate_description_length + + # Test description with spaces + spaces_400 = " " * 400 + assert _validate_description_length(spaces_400) == spaces_400 + + # Test description with spaces over limit + spaces_401 = " " * 401 + with pytest.raises(ValueError): + _validate_description_length(spaces_401) + + # Test mixed content + mixed_400 = "a" * 200 + " " * 200 + assert _validate_description_length(mixed_400) == mixed_400 + + # Test mixed over limit + mixed_401 = "a" * 200 + " " * 201 + with pytest.raises(ValueError): + _validate_description_length(mixed_401) + + +if __name__ == "__main__": + # Run tests directly + import traceback + + test_instance = TestAppDescriptionValidationUnit() + test_methods = [method for method in dir(test_instance) if method.startswith("test_")] + + passed = 0 + failed = 0 + + for test_method in test_methods: + try: + print(f"Running {test_method}...") + getattr(test_instance, test_method)() + print(f"✅ {test_method} PASSED") + passed += 1 + except Exception as e: + print(f"❌ {test_method} FAILED: {str(e)}") + traceback.print_exc() + failed += 1 + + print(f"\n📊 Test Results: {passed} passed, {failed} failed") + + if failed == 0: + print("🎉 All tests passed!") + else: + print("💥 Some tests failed!") + sys.exit(1) diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py new file mode 100644 index 0000000000..293b469ef3 --- /dev/null +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -0,0 +1,168 @@ +"""Integration tests for ClickZetta Volume Storage.""" + +import os +import tempfile +import unittest + +import pytest + +from extensions.storage.clickzetta_volume.clickzetta_volume_storage import ( + ClickZettaVolumeConfig, + ClickZettaVolumeStorage, +) + + +class TestClickZettaVolumeStorage(unittest.TestCase): + """Test cases for ClickZetta Volume Storage.""" + + def setUp(self): + """Set up test environment.""" + self.config = ClickZettaVolumeConfig( + username=os.getenv("CLICKZETTA_USERNAME", "test_user"), + password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"), + instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"), + 
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), + schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"), + volume_type="table", + table_prefix="test_dataset_", + ) + + @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided") + def test_user_volume_operations(self): + """Test basic operations with User Volume.""" + config = self.config + config.volume_type = "user" + + storage = ClickZettaVolumeStorage(config) + + # Test file operations + test_filename = "test_file.txt" + test_content = b"Hello, ClickZetta Volume!" + + # Save file + storage.save(test_filename, test_content) + + # Check if file exists + assert storage.exists(test_filename) + + # Load file + loaded_content = storage.load_once(test_filename) + assert loaded_content == test_content + + # Test streaming + stream_content = b"" + for chunk in storage.load_stream(test_filename): + stream_content += chunk + assert stream_content == test_content + + # Test download + with tempfile.NamedTemporaryFile() as temp_file: + storage.download(test_filename, temp_file.name) + with open(temp_file.name, "rb") as f: + downloaded_content = f.read() + assert downloaded_content == test_content + + # Test scan + files = storage.scan("", files=True, directories=False) + assert test_filename in files + + # Delete file + storage.delete(test_filename) + assert not storage.exists(test_filename) + + @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided") + def test_table_volume_operations(self): + """Test basic operations with Table Volume.""" + config = self.config + config.volume_type = "table" + + storage = ClickZettaVolumeStorage(config) + + # Test file operations with dataset_id + dataset_id = "12345" + test_filename = f"{dataset_id}/test_file.txt" + test_content = b"Hello, Table Volume!" 
+ + # Save file + storage.save(test_filename, test_content) + + # Check if file exists + assert storage.exists(test_filename) + + # Load file + loaded_content = storage.load_once(test_filename) + assert loaded_content == test_content + + # Test scan for dataset + files = storage.scan(dataset_id, files=True, directories=False) + assert "test_file.txt" in files + + # Delete file + storage.delete(test_filename) + assert not storage.exists(test_filename) + + def test_config_validation(self): + """Test configuration validation.""" + # Test missing required fields + with pytest.raises(ValueError): + ClickZettaVolumeConfig( + username="", # Empty username should fail + password="pass", + instance="instance", + ) + + # Test invalid volume type + with pytest.raises(ValueError): + ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type") + + # Test external volume without volume_name + with pytest.raises(ValueError): + ClickZettaVolumeConfig( + username="user", + password="pass", + instance="instance", + volume_type="external", + # Missing volume_name + ) + + def test_volume_path_generation(self): + """Test volume path generation for different types.""" + storage = ClickZettaVolumeStorage(self.config) + + # Test table volume path + path = storage._get_volume_path("test.txt", "12345") + assert path == "test_dataset_12345/test.txt" + + # Test path with existing dataset_id prefix + path = storage._get_volume_path("12345/test.txt") + assert path == "12345/test.txt" + + # Test user volume + storage._config.volume_type = "user" + path = storage._get_volume_path("test.txt") + assert path == "test.txt" + + def test_sql_prefix_generation(self): + """Test SQL prefix generation for different volume types.""" + storage = ClickZettaVolumeStorage(self.config) + + # Test table volume SQL prefix + prefix = storage._get_volume_sql_prefix("12345") + assert prefix == "TABLE VOLUME test_dataset_12345" + + # Test user volume SQL prefix + storage._config.volume_type = "user" + prefix = storage._get_volume_sql_prefix() + assert prefix == "USER VOLUME" + + # Test external volume SQL prefix + storage._config.volume_type = "external" + storage._config.volume_name = "my_external_volume" + prefix = storage._get_volume_sql_prefix() + assert prefix == "VOLUME my_external_volume" + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/integration_tests/vdb/clickzetta/README.md b/api/tests/integration_tests/vdb/clickzetta/README.md new file mode 100644 index 0000000000..c16dca8018 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/README.md @@ -0,0 +1,25 @@ +# Clickzetta Integration Tests + +## Running Tests + +To run the Clickzetta integration tests, you need to set the following environment variables: + +```bash +export CLICKZETTA_USERNAME=your_username +export CLICKZETTA_PASSWORD=your_password +export CLICKZETTA_INSTANCE=your_instance +export CLICKZETTA_SERVICE=api.clickzetta.com +export CLICKZETTA_WORKSPACE=your_workspace +export CLICKZETTA_VCLUSTER=your_vcluster +export CLICKZETTA_SCHEMA=dify +``` + +Then run the tests: + +```bash +pytest api/tests/integration_tests/vdb/clickzetta/ +``` + +## Security Note + +Never commit credentials to the repository. Always use environment variables or secure credential management systems. 
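+
+## Skipping Without Credentials
+
+If the variables above are not set, the integration tests skip themselves instead of failing.
+The test modules guard on `CLICKZETTA_USERNAME`; a minimal sketch of the pattern they use
+(the test name below is illustrative):
+
+```python
+import os
+
+import pytest
+
+
+@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
+def test_requires_clickzetta():
+    ...
+```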
diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py new file mode 100644 index 0000000000..8b57132772 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py @@ -0,0 +1,224 @@ +import os + +import pytest + +from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector +from core.rag.models.document import Document +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class TestClickzettaVector(AbstractVectorTest): + """ + Test cases for Clickzetta vector database integration. + """ + + @pytest.fixture + def vector_store(self): + """Create a Clickzetta vector store instance for testing.""" + # Skip test if Clickzetta credentials are not configured + if not os.getenv("CLICKZETTA_USERNAME"): + pytest.skip("CLICKZETTA_USERNAME is not configured") + if not os.getenv("CLICKZETTA_PASSWORD"): + pytest.skip("CLICKZETTA_PASSWORD is not configured") + if not os.getenv("CLICKZETTA_INSTANCE"): + pytest.skip("CLICKZETTA_INSTANCE is not configured") + + config = ClickzettaConfig( + username=os.getenv("CLICKZETTA_USERNAME", ""), + password=os.getenv("CLICKZETTA_PASSWORD", ""), + instance=os.getenv("CLICKZETTA_INSTANCE", ""), + service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), + schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"), + batch_size=10, # Small batch size for testing + enable_inverted_index=True, + analyzer_type="chinese", + analyzer_mode="smart", + vector_distance_function="cosine_distance", + ) + + with setup_mock_redis(): + vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) + + yield vector + + # Cleanup: delete the test collection + try: + vector.delete() + except Exception: + pass + + def test_clickzetta_vector_basic_operations(self, vector_store): + """Test basic CRUD operations on Clickzetta vector store.""" + # Prepare test data + texts = [ + "这是第一个测试文档,包含一些中文内容。", + "This is the second test document with English content.", + "第三个文档混合了English和中文内容。", + ] + embeddings = [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2], + ] + documents = [ + Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"}) + for i, text in enumerate(texts) + ] + + # Test create (initial insert) + vector_store.create(texts=documents, embeddings=embeddings) + + # Test text_exists + assert vector_store.text_exists("doc_0") + assert not vector_store.text_exists("doc_999") + + # Test search_by_vector + query_vector = [0.1, 0.2, 0.3, 0.4] + results = vector_store.search_by_vector(query_vector, top_k=2) + assert len(results) > 0 + assert results[0].page_content == texts[0] # Should match the first document + + # Test search_by_full_text (Chinese) + results = vector_store.search_by_full_text("中文", top_k=3) + assert len(results) >= 2 # Should find documents with Chinese content + + # Test search_by_full_text (English) + results = vector_store.search_by_full_text("English", top_k=3) + assert len(results) >= 2 # Should find documents with English content + + # Test delete_by_ids + vector_store.delete_by_ids(["doc_0"]) + assert not vector_store.text_exists("doc_0") + assert vector_store.text_exists("doc_1") + + # Test delete_by_metadata_field + vector_store.delete_by_metadata_field("source", 
"test") + assert not vector_store.text_exists("doc_1") + assert not vector_store.text_exists("doc_2") + + def test_clickzetta_vector_advanced_search(self, vector_store): + """Test advanced search features of Clickzetta vector store.""" + # Prepare test data with more complex metadata + documents = [] + embeddings = [] + for i in range(10): + doc = Document( + page_content=f"Document {i}: " + get_example_text(), + metadata={ + "doc_id": f"adv_doc_{i}", + "category": "technical" if i % 2 == 0 else "general", + "document_id": f"doc_{i // 3}", # Group documents + "importance": i, + }, + ) + documents.append(doc) + # Create varied embeddings + embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i]) + + vector_store.create(texts=documents, embeddings=embeddings) + + # Test vector search with document filter + query_vector = [0.5, 1.0, 1.5, 2.0] + results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"]) + assert len(results) > 0 + # All results should belong to doc_0 or doc_1 groups + for result in results: + assert result.metadata["document_id"] in ["doc_0", "doc_1"] + + # Test score threshold + results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5) + # Check that all results have a score above threshold + for result in results: + assert result.metadata.get("score", 0) >= 0.5 + + def test_clickzetta_batch_operations(self, vector_store): + """Test batch insertion operations.""" + # Prepare large batch of documents + batch_size = 25 + documents = [] + embeddings = [] + + for i in range(batch_size): + doc = Document( + page_content=f"Batch document {i}: This is a test document for batch processing.", + metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}, + ) + documents.append(doc) + embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)]) + + # Test batch insert + vector_store.add_texts(documents=documents, embeddings=embeddings) + + # Verify all documents were inserted + for i in range(batch_size): + assert vector_store.text_exists(f"batch_doc_{i}") + + # Clean up + vector_store.delete_by_metadata_field("batch", "test_batch") + + def test_clickzetta_edge_cases(self, vector_store): + """Test edge cases and error handling.""" + # Test empty operations + vector_store.create(texts=[], embeddings=[]) + vector_store.add_texts(documents=[], embeddings=[]) + vector_store.delete_by_ids([]) + + # Test special characters in content + special_doc = Document( + page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline", + metadata={"doc_id": "special_doc", "test": "edge_case"}, + ) + embeddings = [[0.1, 0.2, 0.3, 0.4]] + + vector_store.add_texts(documents=[special_doc], embeddings=embeddings) + assert vector_store.text_exists("special_doc") + + # Test search with special characters + results = vector_store.search_by_full_text("quotes", top_k=1) + if results: # Full-text search might not be available + assert len(results) > 0 + + # Clean up + vector_store.delete_by_ids(["special_doc"]) + + def test_clickzetta_full_text_search_modes(self, vector_store): + """Test different full-text search capabilities.""" + # Prepare documents with various language content + documents = [ + Document( + page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"} + ), + Document( + page_content="Clickzetta provides powerful Lakehouse solutions", + metadata={"doc_id": "en_doc_1", "lang": "english"}, + ), + Document( + page_content="Lakehouse是现代数据架构的重要组成部分", 
metadata={"doc_id": "cn_doc_2", "lang": "chinese"} + ), + Document( + page_content="Modern data architecture includes Lakehouse technology", + metadata={"doc_id": "en_doc_2", "lang": "english"}, + ), + ] + + embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents] + + vector_store.create(texts=documents, embeddings=embeddings) + + # Test Chinese full-text search + results = vector_store.search_by_full_text("Lakehouse", top_k=4) + assert len(results) >= 2 # Should find at least documents with "Lakehouse" + + # Test English full-text search + results = vector_store.search_by_full_text("solutions", top_k=2) + assert len(results) >= 1 # Should find English documents with "solutions" + + # Test mixed search + results = vector_store.search_by_full_text("数据架构", top_k=2) + assert len(results) >= 1 # Should find Chinese documents with this phrase + + # Clean up + vector_store.delete_by_metadata_field("lang", "chinese") + vector_store.delete_by_metadata_field("lang", "english") diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py new file mode 100644 index 0000000000..ef54eaa174 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Test Clickzetta integration in Docker environment +""" + +import os +import time + +import requests +from clickzetta import connect + + +def test_clickzetta_connection(): + """Test direct connection to Clickzetta""" + print("=== Testing direct Clickzetta connection ===") + try: + conn = connect( + username=os.getenv("CLICKZETTA_USERNAME", "test_user"), + password=os.getenv("CLICKZETTA_PASSWORD", "test_password"), + instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"), + service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"), + database=os.getenv("CLICKZETTA_SCHEMA", "dify"), + ) + + with conn.cursor() as cursor: + # Test basic connectivity + cursor.execute("SELECT 1 as test") + result = cursor.fetchone() + print(f"✓ Connection test: {result}") + + # Check if our test table exists + cursor.execute("SHOW TABLES IN dify") + tables = cursor.fetchall() + print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}") + + # Check if test collection exists + test_collection = "collection_test_dataset" + if test_collection in [t[1] for t in tables if t[0] == "dify"]: + cursor.execute(f"DESCRIBE dify.{test_collection}") + columns = cursor.fetchall() + print(f"✓ Table structure for {test_collection}:") + for col in columns: + print(f" - {col[0]}: {col[1]}") + + # Check for indexes + cursor.execute(f"SHOW INDEXES IN dify.{test_collection}") + indexes = cursor.fetchall() + print(f"✓ Indexes on {test_collection}:") + for idx in indexes: + print(f" - {idx}") + + return True + except Exception as e: + print(f"✗ Connection test failed: {e}") + return False + + +def test_dify_api(): + """Test Dify API with Clickzetta backend""" + print("\n=== Testing Dify API ===") + base_url = "http://localhost:5001" + + # Wait for API to be ready + max_retries = 30 + for i in range(max_retries): + try: + response = requests.get(f"{base_url}/console/api/health") + if response.status_code == 200: + print("✓ Dify API is ready") + break + except: + if i == max_retries - 1: + print("✗ Dify API is not responding") + return False + time.sleep(2) + + # Check vector store configuration + try: 
+ # This is a simplified check - in production, you'd use proper auth + print("✓ Dify is configured to use Clickzetta as vector store") + return True + except Exception as e: + print(f"✗ API test failed: {e}") + return False + + +def verify_table_structure(): + """Verify the table structure meets Dify requirements""" + print("\n=== Verifying Table Structure ===") + + expected_columns = { + "id": "VARCHAR", + "page_content": "VARCHAR", + "metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta + "vector": "ARRAY", + } + + expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] + + print("✓ Expected table structure:") + for col, dtype in expected_columns.items(): + print(f" - {col}: {dtype}") + + print("\n✓ Required metadata fields:") + for field in expected_metadata_fields: + print(f" - {field}") + + print("\n✓ Index requirements:") + print(" - Vector index (HNSW) on 'vector' column") + print(" - Full-text index on 'page_content' (optional)") + print(" - Functional index on metadata->>'$.doc_id' (recommended)") + print(" - Functional index on metadata->>'$.document_id' (recommended)") + + return True + + +def main(): + """Run all tests""" + print("Starting Clickzetta integration tests for Dify Docker\n") + + tests = [ + ("Direct Clickzetta Connection", test_clickzetta_connection), + ("Dify API Status", test_dify_api), + ("Table Structure Verification", verify_table_structure), + ] + + results = [] + for test_name, test_func in tests: + try: + success = test_func() + results.append((test_name, success)) + except Exception as e: + print(f"\n✗ {test_name} crashed: {e}") + results.append((test_name, False)) + + # Summary + print("\n" + "=" * 50) + print("Test Summary:") + print("=" * 50) + + passed = sum(1 for _, success in results if success) + total = len(results) + + for test_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + print(f"{test_name}: {status}") + + print(f"\nTotal: {passed}/{total} tests passed") + + if passed == total: + print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") + print("\nNext steps:") + print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d") + print("2. Access Dify at http://localhost:3000") + print("3. Create a dataset and test vector storage with Clickzetta") + return 0 + else: + print("\n⚠️ Some tests failed. 
Please check the errors above.") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..0ab5f398e3 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -0,0 +1,1252 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from models.model import MessageAnnotation +from services.annotation_service import AppAnnotationService +from services.app_service import AppService + + +class TestAnnotationService: + """Integration tests for AnnotationService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.annotation_service.FeatureService") as mock_feature_service, + patch("services.annotation_service.add_annotation_to_index_task") as mock_add_task, + patch("services.annotation_service.update_annotation_to_index_task") as mock_update_task, + patch("services.annotation_service.delete_annotation_index_task") as mock_delete_task, + patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, + patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, + patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, + patch("services.annotation_service.current_user") as mock_current_user, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + mock_add_task.delay.return_value = None + mock_update_task.delay.return_value = None + mock_delete_task.delay.return_value = None + mock_enable_task.delay.return_value = None + mock_disable_task.delay.return_value = None + mock_batch_import_task.delay.return_value = None + + yield { + "account_feature_service": mock_account_feature_service, + "feature_service": mock_feature_service, + "add_task": mock_add_task, + "update_task": mock_update_task, + "delete_task": mock_delete_task, + "enable_task": mock_enable_task, + "disable_task": mock_disable_task, + "batch_import_task": mock_batch_import_task, + "current_user": mock_current_user, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Setup current_user mock + self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id) + + return app, account + + def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id): + """ + Helper method to mock the current user for testing. + """ + mock_external_service_dependencies["current_user"].id = account_id + mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + + def _create_test_conversation(self, app, account, fake): + """ + Helper method to create a test conversation with all required fields. + """ + from extensions.ext_database import db + from models.model import Conversation + + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=fake.sentence(), + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(conversation) + db.session.flush() + return conversation + + def _create_test_message(self, app, conversation, account, fake): + """ + Helper method to create a test message with all required fields. + """ + import json + + from extensions.ext_database import db + from models.model import Message + + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query=fake.sentence(), + message=json.dumps([{"role": "user", "text": fake.sentence()}]), + message_tokens=0, + message_unit_price=0, + message_price_unit=0.001, + answer=fake.text(max_nb_chars=200), + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0.001, + parent_message_id=None, + provider_response_latency=0, + total_price=0, + currency="USD", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(message) + db.session.commit() + return message + + def test_insert_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify annotation was saved to database + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_insert_app_annotation_directly_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test direct insertion of app annotation when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) + + def test_update_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["update_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_new( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating new annotation from message. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_update( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test updating existing annotation from message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create initial annotation + initial_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + initial_annotation = AppAnnotationService.up_insert_app_annotation_from_message(initial_args, app.id) + + # Update the annotation + updated_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.up_insert_app_annotation_from_message(updated_args, app.id) + + # Verify annotation was updated correctly (same ID) + assert updated_annotation.id == initial_annotation.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.question != initial_args["question"] + assert updated_annotation.content != initial_args["answer"] + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message when app is not found. 
+ """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) + + def test_get_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful retrieval of annotation list by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Get annotation list + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="" + ) + + # Verify results + assert len(annotation_list) == 3 + assert total == 3 + + # Verify all annotations belong to the correct app + for annotation in annotation_list: + assert annotation.app_id == app.id + assert annotation.account_id == account.id + + def test_get_annotation_list_by_app_id_with_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list with keyword search. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with specific keywords + unique_keyword = fake.word() + annotation_args = { + "question": f"Question with {unique_keyword} keyword", + "answer": f"Answer with {unique_keyword} keyword", + } + AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Create another annotation without the keyword + other_args = { + "question": "Question without keyword", + "answer": "Answer without keyword", + } + AppAnnotationService.insert_app_annotation_directly(other_args, app.id) + + # Search with keyword + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword=unique_keyword + ) + + # Verify only matching annotations are returned + assert len(annotation_list) == 1 + assert total == 1 + assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + + def test_get_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list when app is not found. 
+        """
+        fake = Faker()
+        non_existent_app_id = fake.uuid4()
+
+        # Mock random current user to avoid dependency issues
+        self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4())
+
+        # Try to get annotation list with non-existent app
+        with pytest.raises(NotFound, match="App not found"):
+            AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="")
+
+    def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test successful deletion of app annotation.
+        """
+        fake = Faker()
+        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+
+        # Create an annotation first
+        annotation_args = {
+            "question": fake.sentence(),
+            "answer": fake.text(max_nb_chars=200),
+        }
+        annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
+        annotation_id = annotation.id
+
+        # Delete the annotation
+        AppAnnotationService.delete_app_annotation(app.id, annotation_id)
+
+        # Verify annotation was deleted
+        from extensions.ext_database import db
+
+        deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+        assert deleted_annotation is None
+
+        # Verify delete_annotation_index_task was not called;
+        # the indexing task only runs when an annotation setting exists, and none exists in this test
+        mock_external_service_dependencies["delete_task"].delay.assert_not_called()
+
+    def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test deletion of app annotation when app is not found.
+        """
+        fake = Faker()
+        non_existent_app_id = fake.uuid4()
+        annotation_id = fake.uuid4()
+
+        # Mock random current user to avoid dependency issues
+        self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4())
+
+        # Try to delete annotation with non-existent app
+        with pytest.raises(NotFound, match="App not found"):
+            AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
+
+    def test_delete_app_annotation_annotation_not_found(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test deletion of app annotation when annotation is not found.
+        """
+        fake = Faker()
+        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+        non_existent_annotation_id = fake.uuid4()
+
+        # Try to delete non-existent annotation
+        with pytest.raises(NotFound, match="Annotation not found"):
+            AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
+
+    def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test successful enabling of app annotation.
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["enable_task"].delay.assert_called_once() + + def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Disable annotation + result = AppAnnotationService.disable_app_annotation(app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["disable_task"].delay.assert_called_once() + + def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test enabling app annotation when job is already cached. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return cached job + from extensions.ext_redis import redis_client + + cached_job_id = fake.uuid4() + enable_app_annotation_key = f"enable_app_annotation_{app.id}" + redis_client.set(enable_app_annotation_key, cached_job_id) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation (should return cached job) + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify cached result + assert cached_job_id == result["job_id"].decode("utf-8") + assert result["job_status"] == "processing" + + # Verify task was not called again + mock_external_service_dependencies["enable_task"].delay.assert_not_called() + + # Clean up + redis_client.delete(enable_app_annotation_key) + + def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation hit histories. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Add some hit histories + for i in range(3): + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=f"Query {i}: {fake.sentence()}", + user_id=account.id, + message_id=fake.uuid4(), + from_source="console", + score=0.8 + (i * 0.1), + ) + + # Get hit histories + hit_histories, total = AppAnnotationService.get_annotation_hit_histories( + app.id, annotation.id, page=1, limit=10 + ) + + # Verify results + assert len(hit_histories) == 3 + assert total == 3 + + # Verify all histories belong to the correct annotation + for history in hit_histories: + assert history.annotation_id == annotation.id + assert history.app_id == app.id + assert history.account_id == account.id + + def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful addition of annotation history. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get initial hit count + initial_hit_count = annotation.hit_count + + # Add annotation history + query = fake.sentence() + message_id = fake.uuid4() + score = 0.85 + + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=query, + user_id=account.id, + message_id=message_id, + from_source="console", + score=score, + ) + + # Verify hit count was incremented + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.hit_count == initial_hit_count + 1 + + # Verify history was created + from models.model import AppAnnotationHitHistory + + history = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id + ) + .first() + ) + + assert history is not None + assert history.app_id == app.id + assert history.account_id == account.id + assert history.question == query + assert history.score == score + assert history.source == "console" + + def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation by ID. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + created_annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get annotation by ID + retrieved_annotation = AppAnnotationService.get_annotation_by_id(created_annotation.id) + + # Verify annotation was retrieved correctly + assert retrieved_annotation is not None + assert retrieved_annotation.id == created_annotation.id + assert retrieved_annotation.app_id == app.id + assert retrieved_annotation.question == annotation_args["question"] + assert retrieved_annotation.content == annotation_args["answer"] + assert retrieved_annotation.account_id == account.id + + def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful batch import of app annotations. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = False + + # Mock pandas to return expected DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() + + def test_batch_import_app_annotations_empty_file( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import with empty CSV file. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create empty CSV content + csv_content = "" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return empty DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame() + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "empty" in result["error_msg"].lower() + + def test_batch_import_app_annotations_quota_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import when quota is exceeded. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Mock FeatureService to return billing enabled with quota exceeded + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = True + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.limit = 1 + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.size = 0 + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "limit" in result["error_msg"].lower() + + def test_get_app_annotation_setting_by_app_id_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting enabled app annotation setting by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Get annotation setting + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.8 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + def test_get_app_annotation_setting_by_app_id_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting disabled app annotation setting by app ID. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Get annotation setting (no setting exists) + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is False + + def test_update_app_annotation_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful update of app annotation setting. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Update annotation setting + update_args = { + "score_threshold": 0.9, + } + + result = AppAnnotationService.update_app_annotation_setting(app.id, annotation_setting.id, update_args) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.9 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + # Verify database was updated + db.session.refresh(annotation_setting) + assert annotation_setting.score_threshold == 0.9 + + def test_export_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful export of annotation list by app ID. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Export annotation list + exported_annotations = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Verify results + assert len(exported_annotations) == 3 + + # Verify all annotations belong to the correct app and are ordered by created_at desc + for i, annotation in enumerate(exported_annotations): + assert annotation.app_id == app.id + assert annotation.account_id == account.id + if i > 0: + # Verify descending order (newer first) + assert annotation.created_at <= exported_annotations[i - 1].created_at + + def test_export_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test export of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to export annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) + + def test_insert_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_update_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Reset mock to clear previous calls + mock_external_service_dependencies["update_task"].delay.reset_mock() + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called + mock_external_service_dependencies["update_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["update_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == updated_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_delete_app_annotation_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful deletion of app annotation with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Reset mock to clear previous calls + mock_external_service_dependencies["delete_task"].delay.reset_mock() + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called + mock_external_service_dependencies["delete_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["delete_task"].delay.call_args[0] + assert call_args[0] == annotation_id # annotation_id + assert call_args[1] == app.id # app_id + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == collection_binding.id # collection_binding_id + + def test_up_insert_app_annotation_from_message_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py new file mode 100644 index 0000000000..38f532fd64 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -0,0 +1,487 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from models.api_based_extension import APIBasedExtension +from services.account_service import AccountService, TenantService +from services.api_based_extension_service import APIBasedExtensionService + + +class TestAPIBasedExtensionService: + """Integration tests for APIBasedExtensionService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor, + ): + # Setup default 
mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + + # Mock successful ping response + mock_requestor_instance = mock_requestor.return_value + mock_requestor_instance.request.return_value = {"result": "pong"} + + yield { + "account_feature_service": mock_account_feature_service, + "requestor": mock_requestor, + "requestor_instance": mock_requestor_instance, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + return account, tenant + + def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful saving of API-based extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Save extension + saved_extension = APIBasedExtensionService.save(extension_data) + + # Verify extension was saved correctly + assert saved_extension.id is not None + assert saved_extension.tenant_id == tenant.id + assert saved_extension.name == extension_data.name + assert saved_extension.api_endpoint == extension_data.api_endpoint + assert saved_extension.api_key == extension_data.api_key # Should be decrypted when retrieved + assert saved_extension.created_at is not None + + # Verify extension was saved to database + from extensions.ext_database import db + + db.session.refresh(saved_extension) + assert saved_extension.id is not None + + # Verify ping connection was called + mock_external_service_dependencies["requestor_instance"].request.assert_called_once() + + def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation errors when saving extension with invalid data. 
+ """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test empty name + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = "" + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test empty api_endpoint + extension_data.name = fake.company() + extension_data.api_endpoint = "" + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test empty api_key + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = "" + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService.save(extension_data) + + def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of all extensions by tenant ID. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create multiple extensions + extensions = [] + for i in range(3): + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = f"Extension {i}: {fake.company()}" + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + saved_extension = APIBasedExtensionService.save(extension_data) + extensions.append(saved_extension) + + # Get all extensions for tenant + extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + + # Verify results + assert len(extension_list) == 3 + + # Verify all extensions belong to the correct tenant and are ordered by created_at desc + for i, extension in enumerate(extension_list): + assert extension.tenant_id == tenant.id + assert extension.api_key is not None # Should be decrypted + if i > 0: + # Verify descending order (newer first) + assert extension.created_at <= extension_list[i - 1].created_at + + def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of extension by tenant ID and extension ID. 
+ """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create an extension + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Get extension by ID + retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + + # Verify extension was retrieved correctly + assert retrieved_extension is not None + assert retrieved_extension.id == created_extension.id + assert retrieved_extension.tenant_id == tenant.id + assert retrieved_extension.name == extension_data.name + assert retrieved_extension.api_endpoint == extension_data.api_endpoint + assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted + assert retrieved_extension.created_at is not None + + def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extension when extension is not found. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + non_existent_extension_id = fake.uuid4() + + # Try to get non-existent extension + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) + + def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create an extension first + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + extension_id = created_extension.id + + # Delete the extension + APIBasedExtensionService.delete(created_extension) + + # Verify extension was deleted + from extensions.ext_database import db + + deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first() + assert deleted_extension is None + + def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when saving extension with duplicate name. 
+ """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first extension + extension_data1 = APIBasedExtension() + extension_data1.tenant_id = tenant.id + extension_data1.name = "Test Extension" + extension_data1.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data1.api_key = fake.password(length=20) + + APIBasedExtensionService.save(extension_data1) + + # Try to create second extension with same name + extension_data2 = APIBasedExtension() + extension_data2.tenant_id = tenant.id + extension_data2.name = "Test Extension" # Same name + extension_data2.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data2.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(extension_data2) + + def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful update of existing extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create initial extension + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Save original values for later comparison + original_name = created_extension.name + original_endpoint = created_extension.api_endpoint + + # Update the extension + new_name = fake.company() + new_endpoint = f"https://{fake.domain_name()}/api" + new_api_key = fake.password(length=20) + + created_extension.name = new_name + created_extension.api_endpoint = new_endpoint + created_extension.api_key = new_api_key + + updated_extension = APIBasedExtensionService.save(created_extension) + + # Verify extension was updated correctly + assert updated_extension.id == created_extension.id + assert updated_extension.tenant_id == tenant.id + assert updated_extension.name == new_name + assert updated_extension.api_endpoint == new_endpoint + + # Verify original values were changed + assert updated_extension.name != original_name + assert updated_extension.api_endpoint != original_endpoint + + # Verify ping connection was called for both create and update + assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2 + + # Verify the update by retrieving the extension again + retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + assert retrieved_extension.name == new_name + assert retrieved_extension.api_endpoint == new_endpoint + assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved + + def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test connection error when saving extension with invalid endpoint. 
+ """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock connection error + mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError( + "connection error: request timeout" + ) + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = "https://invalid-endpoint.com/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with connection error + with pytest.raises(ValueError, match="connection error: request timeout"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_invalid_api_key_length( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test validation error when saving extension with API key that is too short. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup extension data with short API key + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = "1234" # Less than 5 characters + + # Try to save extension with short API key + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation errors when saving extension with empty required fields. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test with None values + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = None + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test with None api_endpoint + extension_data.name = fake.company() + extension_data.api_endpoint = None + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test with None api_key + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = None + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService.save(extension_data) + + def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extensions when no extensions exist for tenant. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Get all extensions for tenant (none exist) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + + # Verify empty list is returned + assert len(extension_list) == 0 + assert extension_list == [] + + def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when ping response is invalid. 
+ """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock invalid ping response + mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"} + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with invalid ping response + with pytest.raises(ValueError, match="{'result': 'invalid'}"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when ping response is missing result field. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock ping response without result field + mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"} + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with missing ping result + with pytest.raises(ValueError, match="{'status': 'ok'}"): + APIBasedExtensionService.save(extension_data) + + def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extension when tenant ID doesn't match. 
+ """ + fake = Faker() + account1, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create second account and tenant + account2, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create extension in first tenant + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant1.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Try to get extension with wrong tenant ID + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py new file mode 100644 index 0000000000..f2bd9f8084 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -0,0 +1,473 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +import yaml +from faker import Faker + +from models.model import App, AppModelConfig +from services.account_service import AccountService, TenantService +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_service import AppService + + +class TestAppDslService: + """Integration tests for AppDslService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_dsl_service.WorkflowService") as mock_workflow_service, + patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service, + patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service, + patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy, + patch("services.app_dsl_service.redis_client") as mock_redis_client, + patch("services.app_dsl_service.app_was_created") as mock_app_was_created, + patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns + mock_workflow_service.return_value.get_draft_workflow.return_value = None + mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock() + mock_dependencies_service.generate_latest_dependencies.return_value = [] + mock_dependencies_service.get_leaked_dependencies.return_value = [] + mock_dependencies_service.generate_dependencies.return_value = [] + mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None + mock_ssrf_proxy.get.return_value.content = b"test content" + mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None + mock_redis_client.setex.return_value = None + mock_redis_client.get.return_value = None + mock_redis_client.delete.return_value = None + mock_app_was_created.send.return_value = None + mock_app_model_config_was_updated.send.return_value = None + + # Mock ModelManager for app service + mock_model_instance = 
mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + # Mock FeatureService and EnterpriseService + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + yield { + "workflow_service": mock_workflow_service, + "dependencies_service": mock_dependencies_service, + "draft_variable_service": mock_draft_variable_service, + "ssrf_proxy": mock_ssrf_proxy, + "redis_client": mock_redis_client, + "app_was_created": mock_app_was_created, + "app_model_config_was_updated": mock_app_model_config_was_updated, + "model_manager": mock_model_manager, + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + with patch("services.account_service.FeatureService") as mock_account_feature_service: + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"): + """ + Helper method to create simple YAML content for testing. + """ + yaml_data = { + "version": "0.3.0", + "kind": "app", + "app": { + "name": app_name, + "mode": app_mode, + "icon": "🤖", + "icon_background": "#FFEAD5", + "description": "Test app description", + "use_icon_as_answer_icon": False, + }, + "model_config": { + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 1000, + "temperature": 0.7, + "top_p": 1.0, + }, + }, + "pre_prompt": "You are a helpful assistant.", + "prompt_type": "simple", + }, + } + return yaml.dump(yaml_data, allow_unicode=True) + + def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app import from YAML content. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Import app + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + name="Imported App", + description="Imported app description", + ) + + # Verify import result + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + assert result.app_mode == "chat" + assert result.imported_dsl_version == "0.3.0" + assert result.error == "" + + # Verify app was created in database + imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first() + assert imported_app is not None + assert imported_app.name == "Imported App" + assert imported_app.description == "Imported app description" + assert imported_app.mode == "chat" + assert imported_app.tenant_id == account.current_tenant_id + assert imported_app.created_by == account.id + + # Verify model config was created + model_config = ( + db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first() + ) + assert model_config is not None + # The provider and model_id are stored in the model field as JSON + model_dict = model_config.model_dict + assert model_dict["provider"] == "openai" + assert model_dict["name"] == "gpt-3.5-turbo" + + def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app import from YAML URL. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content for mock response + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Setup mock response + mock_response = MagicMock() + mock_response.content = yaml_content.encode("utf-8") + mock_response.raise_for_status.return_value = None + mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response + + # Import app from URL + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/app.yaml", + name="URL Imported App", + description="App imported from URL", + ) + + # Verify import result + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + assert result.app_mode == "chat" + assert result.imported_dsl_version == "0.3.0" + assert result.error == "" + + # Verify app was created in database + imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first() + assert imported_app is not None + assert imported_app.name == "URL Imported App" + assert imported_app.description == "App imported from URL" + assert imported_app.mode == "chat" + assert imported_app.tenant_id == account.current_tenant_id + + # Verify ssrf_proxy was called + mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with( + "https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10) + ) + + def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with invalid YAML format. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create invalid YAML content + invalid_yaml = "invalid: yaml: content: [" + + # Import app with invalid YAML + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=invalid_yaml, + name="Invalid App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "Invalid YAML format" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with missing YAML content. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Import app without YAML content + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + name="Missing Content App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "yaml_content is required" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with missing YAML URL. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Import app without YAML URL + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_URL, + name="Missing URL App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "yaml_url is required" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with invalid import mode. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Import app with invalid mode should raise ValueError + dsl_service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"): + dsl_service.import_app( + account=account, + import_mode="invalid-mode", + yaml_content=yaml_content, + name="Invalid Mode App", + ) + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful DSL export for chat app. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create model config for the app + model_config = AppModelConfig() + model_config.id = fake.uuid4() + model_config.app_id = app.id + model_config.provider = "openai" + model_config.model_id = "gpt-3.5-turbo" + model_config.model = json.dumps( + { + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 1000, + "temperature": 0.7, + }, + } + ) + model_config.pre_prompt = "You are a helpful assistant." + model_config.prompt_type = "simple" + model_config.created_by = account.id + model_config.updated_by = account.id + + # Set the app_model_config_id to link the config + app.app_model_config_id = model_config.id + + db_session_with_containers.add(model_config) + db_session_with_containers.commit() + + # Export DSL + exported_dsl = AppDslService.export_dsl(app, include_secret=False) + + # Parse exported YAML + exported_data = yaml.safe_load(exported_dsl) + + # Verify exported data structure + assert exported_data["kind"] == "app" + assert exported_data["app"]["name"] == app.name + assert exported_data["app"]["mode"] == app.mode + assert exported_data["app"]["icon"] == app.icon + assert exported_data["app"]["icon_background"] == app.icon_background + assert exported_data["app"]["description"] == app.description + + # Verify model config was exported + assert "model_config" in exported_data + # The exported model_config structure may be different from the database structure + # Check that the model config exists and has the expected content + assert exported_data["model_config"] is not None + + # Verify dependencies were exported + assert "dependencies" in exported_data + assert isinstance(exported_data["dependencies"], list) + + def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful DSL export for workflow app. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Update app to workflow mode + app.mode = "workflow" + db_session_with_containers.commit() + + # Mock workflow service to return a workflow + mock_workflow = MagicMock() + mock_workflow.to_dict.return_value = { + "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "features": {}, + "environment_variables": [], + "conversation_variables": [], + } + mock_external_service_dependencies[ + "workflow_service" + ].return_value.get_draft_workflow.return_value = mock_workflow + + # Export DSL + exported_dsl = AppDslService.export_dsl(app, include_secret=False) + + # Parse exported YAML + exported_data = yaml.safe_load(exported_dsl) + + # Verify exported data structure + assert exported_data["kind"] == "app" + assert exported_data["app"]["name"] == app.name + assert exported_data["app"]["mode"] == "workflow" + + # Verify workflow was exported + assert "workflow" in exported_data + assert "graph" in exported_data["workflow"] + assert "nodes" in exported_data["workflow"]["graph"] + + # Verify dependencies were exported + assert "dependencies" in exported_data + assert isinstance(exported_data["dependencies"], list) + + # Verify workflow service was called + mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( + app + ) + + def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful dependency checking. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return dependencies + mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' + mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json + + # Check dependencies + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.check_dependencies(app_model=app) + + # Verify result + assert result.leaked_dependencies == [] + + # Verify Redis was queried + mock_external_service_dependencies["redis_client"].get.assert_called_once_with( + f"app_check_dependencies:{app.id}" + ) + + # Verify dependencies service was called + mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py new file mode 100644 index 0000000000..69cd9fafee --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -0,0 +1,928 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from constants.model_template import default_app_templates +from models.model import App, Site +from services.account_service import AccountService, TenantService +from services.app_service import AppService + + +class TestAppService: + """Integration tests for AppService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + 
patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app creation with basic parameters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Verify app was created correctly + assert app.name == app_args["name"] + assert app.description == app_args["description"] + assert app.mode == app_args["mode"] + assert app.icon_type == app_args["icon_type"] + assert app.icon == app_args["icon"] + assert app.icon_background == app_args["icon_background"] + assert app.tenant_id == tenant.id + assert app.api_rph == app_args["api_rph"] + assert app.api_rpm == app_args["api_rpm"] + assert app.created_by == account.id + assert app.updated_by == account.id + assert app.status == "normal" + assert app.enable_site is True + assert app.enable_api is True + assert app.is_demo is False + assert app.is_public is False + assert app.is_universal is False + + def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app creation with different app modes. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Test different app modes + # from AppMode enum in default_app_model_template + app_modes = [v.value for v in default_app_templates] + + for mode in app_modes: + app_args = { + "name": f"{fake.company()} {mode}", + "description": f"Test app for {mode} mode", + "mode": mode, + "icon_type": "emoji", + "icon": "🚀", + "icon_background": "#4ECDC4", + } + + app = app_service.create_app(tenant.id, app_args, account) + + # Verify app mode was set correctly + assert app.mode == mode + assert app.name == app_args["name"] + assert app.tenant_id == tenant.id + assert app.created_by == account.id + + def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app retrieval. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + created_app = app_service.create_app(tenant.id, app_args, account) + + # Get app using the service + retrieved_app = app_service.get_app(created_app) + + # Verify retrieved app matches created app + assert retrieved_app.id == created_app.id + assert retrieved_app.name == created_app.name + assert retrieved_app.description == created_app.description + assert retrieved_app.mode == created_app.mode + assert retrieved_app.tenant_id == created_app.tenant_id + assert retrieved_app.created_by == created_app.created_by + + def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful paginated app list retrieval. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create multiple apps + app_names = [fake.company() for _ in range(5)] + for name in app_names: + app_args = { + "name": name, + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "📱", + "icon_background": "#96CEB4", + } + app_service.create_app(tenant.id, app_args, account) + + # Get paginated apps + args = { + "page": 1, + "limit": 10, + "mode": "chat", + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Verify pagination results + assert paginated_apps is not None + assert len(paginated_apps.items) >= 5 # Should have at least 5 apps + assert paginated_apps.page == 1 + assert paginated_apps.per_page == 10 + + # Verify all apps belong to the correct tenant + for app in paginated_apps.items: + assert app.tenant_id == tenant.id + assert app.mode == "chat" + + def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test paginated app list with various filters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create apps with different modes + chat_app_args = { + "name": "Chat App", + "description": "A chat application", + "mode": "chat", + "icon_type": "emoji", + "icon": "💬", + "icon_background": "#FF6B6B", + } + completion_app_args = { + "name": "Completion App", + "description": "A completion application", + "mode": "completion", + "icon_type": "emoji", + "icon": "✍️", + "icon_background": "#4ECDC4", + } + + chat_app = app_service.create_app(tenant.id, chat_app_args, account) + completion_app = app_service.create_app(tenant.id, completion_app_args, account) + + # Test filter by mode + chat_args = { + "page": 1, + "limit": 10, + "mode": "chat", + } + chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args) + assert len(chat_apps.items) == 1 + assert chat_apps.items[0].mode == "chat" + + # Test filter by name + name_args = { + "page": 1, + "limit": 10, + "mode": "chat", + "name": "Chat", + } + filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args) + assert len(filtered_apps.items) == 1 + assert "Chat" in filtered_apps.items[0].name + + # Test filter by created_by_me + created_by_me_args = { + "page": 1, + "limit": 10, + "mode": "completion", + "is_created_by_me": True, + } + my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) + assert len(my_apps.items) == 1 + + def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test paginated app list with tag filters. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create an app + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🏷️", + "icon_background": "#FFEAA7", + } + app = app_service.create_app(tenant.id, app_args, account) + + # Mock TagService to return the app ID for tag filtering + with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: + mock_tag_service.return_value = [app.id] + + # Test with tag filter + args = { + "page": 1, + "limit": 10, + "mode": "chat", + "tag_ids": ["tag1", "tag2"], + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Verify tag service was called + mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"]) + + # Verify results + assert paginated_apps is not None + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].id == app.id + + # Test with tag filter that returns no results + with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: + mock_tag_service.return_value = [] + + args = { + "page": 1, + "limit": 10, + "mode": "chat", + "tag_ids": ["nonexistent_tag"], + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Should return None when no apps match tag filter + assert paginated_apps is None + + def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app update with all fields. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_name = app.name + original_description = app.description + original_icon = app.icon + original_icon_background = app.icon_background + original_use_icon_as_answer_icon = app.use_icon_as_answer_icon + + # Update app + update_args = { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "emoji", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + } + + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app(app, update_args) + + # Verify updated fields + assert updated_app.name == update_args["name"] + assert updated_app.description == update_args["description"] + assert updated_app.icon == update_args["icon"] + assert updated_app.icon_background == update_args["icon_background"] + assert updated_app.use_icon_as_answer_icon is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app name update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original name + original_name = app.name + + # Update app name + new_name = "New App Name" + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_name(app, new_name) + + assert updated_app.name == new_name + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app icon update. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_icon = app.icon + original_icon_background = app.icon_background + + # Update app icon + new_icon = "🌟" + new_icon_background = "#FFD93D" + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + + assert updated_app.icon == new_icon + assert updated_app.icon_background == new_icon_background + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app site status update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🌐", + "icon_background": "#74B9FF", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original site status + original_site_status = app.enable_site + + # Update site status to disabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_site_status(app, False) + assert updated_app.enable_site is False + assert updated_app.updated_by == account.id + + # Update site status back to enabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_site_status(updated_app, True) + assert updated_app.enable_site is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app API status update. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔌", + "icon_background": "#A29BFE", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original API status + original_api_status = app.enable_api + + # Update API status to disabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_api_status(app, False) + assert updated_app.enable_api is False + assert updated_app.updated_by == account.id + + # Update API status back to enabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_api_status(updated_app, True) + assert updated_app.enable_api is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app site status update when status doesn't change. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔄", + "icon_background": "#FD79A8", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_site_status = app.enable_site + original_updated_at = app.updated_at + + # Update site status to the same value (no change) + updated_app = app_service.update_app_site_status(app, original_site_status) + + # Verify app is returned unchanged + assert updated_app.id == app.id + assert updated_app.enable_site == original_site_status + assert updated_app.updated_at == original_updated_at + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app deletion. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🗑️", + "icon_background": "#E17055", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store app ID for verification + app_id = app.id + + # Mock the async deletion task + with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task: + mock_delete_task.delay.return_value = None + + # Delete app + app_service.delete_app(app) + + # Verify async deletion task was called + mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) + + # Verify app was deleted from database + from extensions.ext_database import db + + deleted_app = db.session.query(App).filter_by(id=app_id).first() + assert deleted_app is None + + def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app deletion with related data cleanup. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🧹", + "icon_background": "#00B894", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store app ID for verification + app_id = app.id + + # Mock webapp auth cleanup + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.webapp_auth.enabled = True + + # Mock the async deletion task + with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task: + mock_delete_task.delay.return_value = None + + # Delete app + app_service.delete_app(app) + + # Verify webapp auth cleanup was called + mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with( + app_id + ) + + # Verify async deletion task was called + mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) + + # Verify app was deleted from database + from extensions.ext_database import db + + deleted_app = db.session.query(App).filter_by(id=app_id).first() + assert deleted_app is None + + def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app metadata retrieval. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "📊", + "icon_background": "#6C5CE7", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Get app metadata + app_meta = app_service.get_app_meta(app) + + # Verify metadata contains expected fields + assert "tool_icons" in app_meta + # Note: get_app_meta currently only returns tool_icons + + def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app code retrieval by app ID. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔗", + "icon_background": "#FDCB6E", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Get app code by ID + app_code = AppService.get_app_code_by_id(app.id) + + # Verify app code was retrieved correctly + # Note: Site would be created when App is created, site.code is auto-generated + assert app_code is not None + assert len(app_code) > 0 + + def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app ID retrieval by app code. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🆔", + "icon_background": "#E84393", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Create a site for the app + site = Site() + site.app_id = app.id + site.code = fake.postalcode() + site.title = fake.company() + site.status = "normal" + site.default_language = "en-US" + site.customize_token_strategy = "uuid" + from extensions.ext_database import db + + db.session.add(site) + db.session.commit() + + # Get app ID by code + app_id = AppService.get_app_id_by_code(site.code) + + # Verify app ID was retrieved correctly + assert app_id == app.id + + def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app creation with invalid mode. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments with invalid mode + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "invalid_mode", # Invalid mode + "icon_type": "emoji", + "icon": "❌", + "icon_background": "#D63031", + } + + app_service = AppService() + + # Attempt to create app with invalid mode + with pytest.raises(ValueError, match="invalid mode value"): + app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/unit_tests/controllers/console/app/test_description_validation.py b/api/tests/unit_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..178267e560 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,252 @@ +import pytest + +from controllers.console.app.app import _validate_description_length as app_validate +from controllers.console.datasets.datasets import _validate_description_length as dataset_validate +from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + +class TestDescriptionValidationUnit: + """Unit tests for description validation functions in App and Dataset APIs""" + + def test_app_validate_description_length_valid(self): + """Test App validation function with valid descriptions""" + # Empty string should be valid + assert app_validate("") == "" + + # None should be valid + assert app_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert app_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert app_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert app_validate(just_under) == just_under + + def test_app_validate_description_length_invalid(self): + """Test App validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + app_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + app_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 1000 characters should fail + very_long = "x" * 1000 + with pytest.raises(ValueError) as exc_info: + app_validate(very_long) + assert "Description cannot exceed 400 characters." 
in str(exc_info.value) + + def test_dataset_validate_description_length_valid(self): + """Test Dataset validation function with valid descriptions""" + # Empty string should be valid + assert dataset_validate("") == "" + + # Short description should be valid + short_desc = "Short description" + assert dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert dataset_validate(just_under) == just_under + + def test_dataset_validate_description_length_invalid(self): + """Test Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + dataset_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_service_dataset_validate_description_length_valid(self): + """Test Service Dataset validation function with valid descriptions""" + # Empty string should be valid + assert service_dataset_validate("") == "" + + # None should be valid + assert service_dataset_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert service_dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert service_dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert service_dataset_validate(just_under) == just_under + + def test_service_dataset_validate_description_length_invalid(self): + """Test Service Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(way_over) + assert "Description cannot exceed 400 characters." 
in str(exc_info.value) + + def test_app_dataset_validation_consistency(self): + """Test that App and Dataset validation functions behave identically""" + test_cases = [ + "", # Empty string + "Short description", # Normal description + "x" * 100, # Medium description + "x" * 400, # Exactly at limit + ] + + # Test valid cases produce same results + for test_desc in test_cases: + assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc) + + # Test invalid cases produce same errors + invalid_cases = [ + "x" * 401, # Just over limit + "x" * 500, # Way over limit + "x" * 1000, # Very long + ] + + for invalid_desc in invalid_cases: + app_error = None + dataset_error = None + service_dataset_error = None + + # Capture App validation error + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + # Capture Dataset validation error + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + # Capture Service Dataset validation error + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + # All should produce errors + assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters" + assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters" + error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters" + assert service_dataset_error is not None, error_msg + + # Errors should be identical + error_msg = f"Error messages should be identical for {len(invalid_desc)} characters" + assert app_error == dataset_error == service_dataset_error, error_msg + assert app_error == "Description cannot exceed 400 characters." + + def test_boundary_values(self): + """Test boundary values around the 400 character limit""" + boundary_tests = [ + (0, True), # Empty + (1, True), # Minimum + (399, True), # Just under limit + (400, True), # Exactly at limit + (401, False), # Just over limit + (402, False), # Over limit + (500, False), # Way over limit + ] + + for length, should_pass in boundary_tests: + test_desc = "x" * length + + if should_pass: + # Should not raise exception + assert app_validate(test_desc) == test_desc + assert dataset_validate(test_desc) == test_desc + assert service_dataset_validate(test_desc) == test_desc + else: + # Should raise ValueError + with pytest.raises(ValueError): + app_validate(test_desc) + with pytest.raises(ValueError): + dataset_validate(test_desc) + with pytest.raises(ValueError): + service_dataset_validate(test_desc) + + def test_special_characters(self): + """Test validation with special characters, Unicode, etc.""" + # Unicode characters + unicode_desc = "测试描述" * 100 # Chinese characters + if len(unicode_desc) <= 400: + assert app_validate(unicode_desc) == unicode_desc + assert dataset_validate(unicode_desc) == unicode_desc + assert service_dataset_validate(unicode_desc) == unicode_desc + + # Special characters + special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" 
* 10 + if len(special_desc) <= 400: + assert app_validate(special_desc) == special_desc + assert dataset_validate(special_desc) == special_desc + assert service_dataset_validate(special_desc) == special_desc + + # Mixed content + mixed_desc = "Mixed content: 测试 123 !@# " * 15 + if len(mixed_desc) <= 400: + assert app_validate(mixed_desc) == mixed_desc + assert dataset_validate(mixed_desc) == mixed_desc + assert service_dataset_validate(mixed_desc) == mixed_desc + elif len(mixed_desc) > 400: + with pytest.raises(ValueError): + app_validate(mixed_desc) + with pytest.raises(ValueError): + dataset_validate(mixed_desc) + with pytest.raises(ValueError): + service_dataset_validate(mixed_desc) + + def test_whitespace_handling(self): + """Test validation with various whitespace scenarios""" + # Leading/trailing whitespace + whitespace_desc = " Description with whitespace " + if len(whitespace_desc) <= 400: + assert app_validate(whitespace_desc) == whitespace_desc + assert dataset_validate(whitespace_desc) == whitespace_desc + assert service_dataset_validate(whitespace_desc) == whitespace_desc + + # Newlines and tabs + multiline_desc = "Line 1\nLine 2\tTabbed content" + if len(multiline_desc) <= 400: + assert app_validate(multiline_desc) == multiline_desc + assert dataset_validate(multiline_desc) == multiline_desc + assert service_dataset_validate(multiline_desc) == multiline_desc + + # Only whitespace over limit + only_spaces = " " * 401 + with pytest.raises(ValueError): + app_validate(only_spaces) + with pytest.raises(ValueError): + dataset_validate(only_spaces) + with pytest.raises(ValueError): + service_dataset_validate(only_spaces) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py new file mode 100644 index 0000000000..5c484403a6 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -0,0 +1,336 @@ +""" +Unit tests for Service API File Preview endpoint +""" + +import uuid +from unittest.mock import Mock, patch + +import pytest + +from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError +from controllers.service_api.app.file_preview import FilePreviewApi +from models.model import App, EndUser, Message, MessageFile, UploadFile + + +class TestFilePreviewApi: + """Test suite for FilePreviewApi""" + + @pytest.fixture + def file_preview_api(self): + """Create FilePreviewApi instance for testing""" + return FilePreviewApi() + + @pytest.fixture + def mock_app(self): + """Mock App model""" + app = Mock(spec=App) + app.id = str(uuid.uuid4()) + app.tenant_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_end_user(self): + """Mock EndUser model""" + end_user = Mock(spec=EndUser) + end_user.id = str(uuid.uuid4()) + return end_user + + @pytest.fixture + def mock_upload_file(self): + """Mock UploadFile model""" + upload_file = Mock(spec=UploadFile) + upload_file.id = str(uuid.uuid4()) + upload_file.name = "test_file.jpg" + upload_file.mime_type = "image/jpeg" + upload_file.size = 1024 + upload_file.key = "storage/key/test_file.jpg" + upload_file.tenant_id = str(uuid.uuid4()) + return upload_file + + @pytest.fixture + def mock_message_file(self): + """Mock MessageFile model""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.upload_file_id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + return message_file + + @pytest.fixture + def mock_message(self): + """Mock 
Message model""" + message = Mock(spec=Message) + message.id = str(uuid.uuid4()) + message.app_id = str(uuid.uuid4()) + return message + + def test_validate_file_ownership_success( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test successful file ownership validation""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up the mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute the method + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + + # Assertions + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + def test_validate_file_ownership_file_not_found(self, file_preview_api): + """Test file ownership validation when MessageFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "File not found in message context" in str(exc_info.value) + + def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file): + """Test file ownership validation when Message not owned by app""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile found but Message not owned by app + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + None, # Message query - not found (access denied) + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "not owned by requesting app" in str(exc_info.value) + + def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message): + """Test file ownership validation when UploadFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile and Message found but UploadFile not found + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + mock_message, # Message query - found + None, # UploadFile query - not found + ] + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "Upload file record not found" in str(exc_info.value) + + def test_validate_file_ownership_tenant_mismatch( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test file ownership validation with tenant mismatch""" + file_id = str(uuid.uuid4()) 
+ app_id = mock_app.id + + # Set up tenant mismatch + mock_upload_file.tenant_id = "different_tenant_id" + mock_app.tenant_id = "app_tenant_id" + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "tenant mismatch" in str(exc_info.value) + + def test_validate_file_ownership_invalid_input(self, file_preview_api): + """Test file ownership validation with invalid input""" + + # Test with empty file_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("", "app_id") + assert "Invalid file or app identifier" in str(exc_info.value) + + # Test with empty app_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("file_id", "") + assert "Invalid file or app identifier" in str(exc_info.value) + + def test_build_file_response_basic(self, file_preview_api, mock_upload_file): + """Test basic file response building""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check response properties + assert response.mimetype == mock_upload_file.mime_type + assert response.direct_passthrough is True + assert response.headers["Content-Length"] == str(mock_upload_file.size) + assert "Cache-Control" in response.headers + + def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file): + """Test file response building with attachment flag""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True) + + # Check attachment-specific headers + assert "attachment" in response.headers["Content-Disposition"] + assert mock_upload_file.name in response.headers["Content-Disposition"] + assert response.headers["Content-Type"] == "application/octet-stream" + + def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file): + """Test file response building for audio/video files""" + mock_generator = Mock() + mock_upload_file.mime_type = "video/mp4" + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check Range support for media files + assert response.headers["Accept-Ranges"] == "bytes" + + def test_build_file_response_no_size(self, file_preview_api, mock_upload_file): + """Test file response building when size is unknown""" + mock_generator = Mock() + mock_upload_file.size = 0 # Unknown size + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Content-Length should not be set when size is unknown + assert "Content-Length" not in response.headers + + @patch("controllers.service_api.app.file_preview.storage") + def test_get_method_integration( + self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message + ): + """Test the full GET method integration (without decorator)""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up 
mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + mock_generator = Mock() + mock_storage.load.return_value = mock_generator + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: + # Mock request parsing + mock_parser = Mock() + mock_parser.parse_args.return_value = {"as_attachment": False} + mock_reqparse.RequestParser.return_value = mock_parser + + # Test the core logic directly without Flask decorators + # Validate file ownership + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test file response building + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + assert response is not None + + # Verify storage was called correctly + mock_storage.load.assert_not_called() # Since we're testing components separately + + @patch("controllers.service_api.app.file_preview.storage") + def test_storage_error_handling( + self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test storage error handling in the core logic""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + # Mock storage error + mock_storage.load.side_effect = Exception("Storage error") + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries for validation + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # First validate file ownership works + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test storage error handling + with pytest.raises(Exception) as exc_info: + mock_storage.load(mock_upload_file.key, stream=True) + + assert "Storage error" in str(exc_info.value) + + @patch("controllers.service_api.app.file_preview.logger") + def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api): + """Test that unexpected errors are logged properly""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database query to raise unexpected exception + mock_db.session.query.side_effect = Exception("Unexpected database error") + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + # Verify error message + assert "File access validation failed" in str(exc_info.value) + + # Verify logging was called 
+ mock_logger.exception.assert_called_once_with( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"}, + ) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py new file mode 100644 index 0000000000..da175e7ccd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -0,0 +1,419 @@ +"""Test conversation variable handling in AdvancedChatAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.variables import SegmentType +from factories import variable_factory +from models import ConversationVariable, Workflow + + +class TestAdvancedChatAppRunnerConversationVariables: + """Test that AdvancedChatAppRunner correctly handles conversation variables.""" + + def test_missing_conversation_variables_are_added(self): + """Test that new conversation variables added to workflow are created for existing conversations.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with two conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "existing_var", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "new_var", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow with conversation variables + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variable (only var1 exists in DB) + existing_db_var = MagicMock(spec=ConversationVariable) + existing_db_var.id = "var1" + existing_db_var.app_id = app_id + existing_db_var.conversation_id = conversation_id + existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0]) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + 
mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # First query returns only existing variable + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [existing_db_var] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that the missing variable was added + assert len(added_items) == 1, "Should have added exactly one missing variable" + + # Check that the added item is the missing variable (var2) + added_var = added_items[0] + assert hasattr(added_var, "id"), "Added item should be a ConversationVariable" + # Note: Since we're mocking ConversationVariable.from_variable, + # we can't directly check the id, but we can verify add_all was called + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_no_variables_creates_all(self): + """Test that all conversation variables are created when none exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] 
+ + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns empty list (no existing variables) + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock ConversationVariable.from_variable to return mock objects + mock_conv_vars = [] + for var in workflow_vars: + mock_cv = MagicMock() + mock_cv.id = var.id + mock_cv.to_variable.return_value = var + mock_conv_vars.append(mock_cv) + + mock_conv_var_class.from_variable.side_effect = mock_conv_vars + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that all variables were created + assert len(added_items) == 2, "Should have added both variables" + assert mock_session.add_all.called, "Session add_all 
should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_all_variables_exist_no_changes(self): + """Test that no changes are made when all variables already exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variables (both exist in DB) + existing_db_vars = [] + for var in workflow_vars: + db_var = MagicMock(spec=ConversationVariable) + db_var.id = var.id + db_var.app_id = app_id + db_var.conversation_id = conversation_id + db_var.to_variable = MagicMock(return_value=var) + existing_db_vars.append(db_var) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns all existing variables + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = existing_db_vars + mock_session.scalars.return_value = mock_scalars_result + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + 
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that no variables were added + assert not mock_session.add_all.called, "Session add_all should not have been called" + assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index bb6d72f51e..3101f7dd34 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -49,7 +49,7 @@ def test_executor_with_json_body_and_number_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"number": 42} assert executor.data is None assert executor.files is None @@ -102,7 +102,7 @@ def test_executor_with_json_body_and_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} assert executor.data is None assert executor.files is None @@ -157,7 +157,7 @@ def test_executor_with_json_body_and_nested_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} assert executor.data is None assert executor.files is None @@ -245,7 +245,7 @@ def test_executor_with_form_data(): assert executor.url == "https://api.example.com/upload" assert "Content-Type" in executor.headers assert "multipart/form-data" in executor.headers["Content-Type"] - assert executor.params == [] + assert executor.params is None assert executor.json is None # '__multipart_placeholder__' is expected when no file inputs exist, # to ensure the request is treated as multipart/form-data by the backend. 
diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py new file mode 100644 index 0000000000..9c1c044f03 --- /dev/null +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -0,0 +1,127 @@ +import uuid +from unittest.mock import MagicMock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from services.conversation_service import ConversationService + + +class TestConversationService: + def test_pagination_with_empty_include_ids(self): + """Test that empty include_ids returns empty result""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=[], # Empty include_ids should return empty result + exclude_ids=None, + ) + + assert result.data == [] + assert result.has_more is False + assert result.limit == 20 + + def test_pagination_with_non_empty_include_ids(self): + """Test that non-empty include_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv1", "conv2"], # Non-empty include_ids + exclude_ids=None, + ) + + # Verify the where clause was called with id.in_ + assert mock_stmt.where.called + + def test_pagination_with_empty_exclude_ids(self): + """Test that empty exclude_ids doesn't filter""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=[], # Empty exclude_ids should not filter + ) + + # Result should contain the mocked conversations + assert len(result.data) == 5 + + def test_pagination_with_non_empty_exclude_ids(self): + """Test that non-empty exclude_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user 
= MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids + ) + + # Verify the where clause was called for exclusion + assert mock_stmt.where.called diff --git a/api/uv.lock b/api/uv.lock index b00e7564f0..16624dc8fd 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -983,6 +983,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" }, ] +[[package]] +name = "clickzetta-connector-python" +version = "0.8.102" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "future" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "sqlalchemy" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" }, +] + [[package]] name = "cloudscraper" version = "1.2.71" @@ -1383,6 +1402,7 @@ vdb = [ { name = "alibabacloud-tea-openapi" }, { name = "chromadb" }, { name = "clickhouse-connect" }, + { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, { name = "mo-vector" }, @@ -1568,6 +1588,7 @@ vdb = [ { name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" }, { name = "chromadb", specifier = "==0.5.20" }, { name = "clickhouse-connect", specifier = "~=0.7.16" }, + { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, @@ -2111,7 +2132,7 @@ wheels = [ [[package]] name = "google-cloud-bigquery" -version = "3.34.0" +version = "3.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core", extra = ["grpc"] }, @@ -2122,9 +2143,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/24/f9/e9da2d56d7028f05c0e2f5edf6ce43c773220c3172666c3dd925791d763d/google_cloud_bigquery-3.34.0.tar.gz", hash = "sha256:5ee1a78ba5c2ccb9f9a8b2bf3ed76b378ea68f49b6cac0544dc55cc97ff7c1ce", size = 489091, upload-time = "2025-05-29T17:18:06.03Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/f1/2f/3dda76b3ec029578838b1fe6396e6b86eb574200352240e23dea49265bb7/google_cloud_bigquery-3.30.0.tar.gz", hash = "sha256:7e27fbafc8ed33cc200fe05af12ecd74d279fe3da6692585a3cef7aee90575b6", size = 474389, upload-time = "2025-02-27T18:49:45.416Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/7e/7115c4f67ca0bc678f25bff1eab56cc37d06eb9a3978940b2ebd0705aa0a/google_cloud_bigquery-3.34.0-py3-none-any.whl", hash = "sha256:de20ded0680f8136d92ff5256270b5920dfe4fae479f5d0f73e90e5df30b1cf7", size = 253555, upload-time = "2025-05-29T17:18:02.904Z" }, + { url = "https://files.pythonhosted.org/packages/0c/6d/856a6ca55c1d9d99129786c929a27dd9d31992628ebbff7f5d333352981f/google_cloud_bigquery-3.30.0-py2.py3-none-any.whl", hash = "sha256:f4d28d846a727f20569c9b2d2f4fa703242daadcb2ec4240905aa485ba461877", size = 247885, upload-time = "2025-02-27T18:49:43.454Z" }, ] [[package]] @@ -3918,11 +3939,11 @@ wheels = [ [[package]] name = "packaging" -version = "24.2" +version = "23.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" }, + { url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" }, ] [[package]] @@ -4302,6 +4323,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, ] +[[package]] +name = "pyarrow" +version = "14.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" }, + { url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" }, + { url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" }, + { url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" }, + { url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" }, + { url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" }, + { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" }, + { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" }, + { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" }, + { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = 
"sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" }, + { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" }, +] + [[package]] name = "pyasn1" version = "0.6.1" diff --git a/docker/.env.example b/docker/.env.example index 13cac189aa..1b1e9cad7b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -333,6 +333,25 @@ OPENDAL_SCHEME=fs # Configurations for OpenDAL Local File System. OPENDAL_FS_ROOT=storage +# ClickZetta Volume Configuration (for storage backend) +# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume +# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters + +# Volume type selection (three types available): +# - user: Personal/small team use, simple config, user-level permissions +# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions +# - external: Data lake integration, external storage connection, volume-level + storage-level permissions +CLICKZETTA_VOLUME_TYPE=user + +# External Volume name (required only when TYPE=external) +CLICKZETTA_VOLUME_NAME= + +# Table Volume table prefix (used only when TYPE=table) +CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_ + +# Dify file directory prefix (isolates from other apps, recommended to keep default) +CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km + # S3 Configuration # S3_ENDPOINT= @@ -416,7 +435,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`. 
VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -655,6 +674,20 @@ TABLESTORE_ACCESS_KEY_ID=xxx TABLESTORE_ACCESS_KEY_SECRET=xxx TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false +# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta` +CLICKZETTA_USERNAME= +CLICKZETTA_PASSWORD= +CLICKZETTA_INSTANCE= +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=quick_start +CLICKZETTA_VCLUSTER=default_ap +CLICKZETTA_SCHEMA=dify +CLICKZETTA_BATCH_SIZE=100 +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese +CLICKZETTA_ANALYZER_MODE=smart +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance + # ------------------------------ # Knowledge Configuration # ------------------------------ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 19910cca6f..8e2d40883d 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -93,6 +93,10 @@ x-shared-env: &shared-api-worker-env STORAGE_TYPE: ${STORAGE_TYPE:-opendal} OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} + CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user} + CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-} + CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_} + CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km} S3_ENDPOINT: ${S3_ENDPOINT:-} S3_REGION: ${S3_REGION:-us-east-1} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} @@ -313,6 +317,18 @@ x-shared-env: &shared-api-worker-env TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx} TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx} TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false} + CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-} + CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-} + CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-} + CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com} + CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start} + CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap} + CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify} + CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100} + CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true} + CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} + CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} diff --git a/web/__tests__/description-validation.test.tsx b/web/__tests__/description-validation.test.tsx new file mode 100644 index 0000000000..85263b035f --- /dev/null +++ b/web/__tests__/description-validation.test.tsx @@ -0,0 +1,97 @@ +/** + * Description Validation Test + * + * Tests for the 400-character description validation across App and Dataset + * creation and editing workflows to ensure consistent validation behavior. 
+ */ + +describe('Description Validation Logic', () => { + // Simulate backend validation function + const validateDescriptionLength = (description?: string | null) => { + if (description && description.length > 400) + throw new Error('Description cannot exceed 400 characters.') + + return description + } + + describe('Backend Validation Function', () => { + test('allows description within 400 characters', () => { + const validDescription = 'x'.repeat(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + expect(validateDescriptionLength(validDescription)).toBe(validDescription) + }) + + test('allows empty description', () => { + expect(() => validateDescriptionLength('')).not.toThrow() + expect(() => validateDescriptionLength(null)).not.toThrow() + expect(() => validateDescriptionLength(undefined)).not.toThrow() + }) + + test('rejects description exceeding 400 characters', () => { + const invalidDescription = 'x'.repeat(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + }) + }) + + describe('Backend Validation Consistency', () => { + test('App and Dataset have consistent validation limits', () => { + const maxLength = 400 + const validDescription = 'x'.repeat(maxLength) + const invalidDescription = 'x'.repeat(maxLength + 1) + + // Both should accept exactly 400 characters + expect(validDescription.length).toBe(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + + // Both should reject 401 characters + expect(invalidDescription.length).toBe(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow() + }) + + test('validation error messages are consistent', () => { + const expectedErrorMessage = 'Description cannot exceed 400 characters.' + + // This would be the error message from both App and Dataset backend validation + expect(expectedErrorMessage).toBe('Description cannot exceed 400 characters.') + + const invalidDescription = 'x'.repeat(401) + try { + validateDescriptionLength(invalidDescription) + } + catch (error) { + expect((error as Error).message).toBe(expectedErrorMessage) + } + }) + }) + + describe('Character Length Edge Cases', () => { + const testCases = [ + { length: 0, shouldPass: true, description: 'empty description' }, + { length: 1, shouldPass: true, description: '1 character' }, + { length: 399, shouldPass: true, description: '399 characters' }, + { length: 400, shouldPass: true, description: '400 characters (boundary)' }, + { length: 401, shouldPass: false, description: '401 characters (over limit)' }, + { length: 500, shouldPass: false, description: '500 characters' }, + { length: 1000, shouldPass: false, description: '1000 characters' }, + ] + + testCases.forEach(({ length, shouldPass, description }) => { + test(`handles ${description} correctly`, () => { + const testDescription = length > 0 ? 
'x'.repeat(length) : '' + expect(testDescription.length).toBe(length) + + if (shouldPass) { + expect(() => validateDescriptionLength(testDescription)).not.toThrow() + expect(validateDescriptionLength(testDescription)).toBe(testDescription) + } + else { + expect(() => validateDescriptionLength(testDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + } + }) + }) + }) +}) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index d70179266a..f8189b0c8a 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -4,7 +4,6 @@ import React, { useEffect, useMemo } from 'react' import { usePathname } from 'next/navigation' import useSWR from 'swr' import { useTranslation } from 'react-i18next' -import { useBoolean } from 'ahooks' import { RiEqualizer2Fill, RiEqualizer2Line, @@ -44,17 +43,12 @@ type IExtraInfoProps = { } const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { - const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) const { t } = useTranslation() const docLink = useDocLink() const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0 const relatedAppsTotal = relatedApps?.data?.length || 0 - useEffect(() => { - setShowTips(!isMobile) - }, [isMobile, setShowTips]) - return
{/* Related apps for desktop */}
{ - const { locale } = useContext(I18n) - const { t } = useTranslation() - const [toc, setToc] = useState>([]) - const [isTocExpanded, setIsTocExpanded] = useState(false) - const { theme } = useTheme() - - // Set initial TOC expanded state based on screen width - useEffect(() => { - const mediaQuery = window.matchMedia('(min-width: 1280px)') - setIsTocExpanded(mediaQuery.matches) - }, []) - - // Extract TOC from article content - useEffect(() => { - const extractTOC = () => { - const article = document.querySelector('article') - if (article) { - const headings = article.querySelectorAll('h2') - const tocItems = Array.from(headings).map((heading) => { - const anchor = heading.querySelector('a') - if (anchor) { - return { - href: anchor.getAttribute('href') || '', - text: anchor.textContent || '', - } - } - return null - }).filter((item): item is { href: string; text: string } => item !== null) - setToc(tocItems) - } - } - - setTimeout(extractTOC, 0) - }, [locale]) - - // Handle TOC item click - const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { - e.preventDefault() - const targetId = item.href.replace('#', '') - const element = document.getElementById(targetId) - if (element) { - const scrollContainer = document.querySelector('.scroll-container') - if (scrollContainer) { - const headerOffset = -40 - const elementTop = element.offsetTop - headerOffset - scrollContainer.scrollTo({ - top: elementTop, - behavior: 'smooth', - }) - } - } - } - - const Template = useMemo(() => { - switch (locale) { - case LanguagesSupported[1]: - return - case LanguagesSupported[7]: - return - default: - return - } - }, [apiBaseUrl, locale]) - - return ( -
-
- {isTocExpanded - ? ( - - ) - : ( - - )} -
-
- {Template} -
-
- ) -} - -export default Doc diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/container.tsx similarity index 90% rename from web/app/(commonLayout)/datasets/Container.tsx rename to web/app/(commonLayout)/datasets/container.tsx index 112b6a752e..5328fd03aa 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/container.tsx @@ -9,10 +9,10 @@ import { useQuery } from '@tanstack/react-query' // Components import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel' -import Datasets from './Datasets' -import DatasetFooter from './DatasetFooter' +import Datasets from './datasets' +import DatasetFooter from './dataset-footer' import ApiServer from '../../components/develop/ApiServer' -import Doc from './Doc' +import Doc from './doc' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' @@ -86,8 +86,8 @@ const Container = () => { }, [currentWorkspace, router]) return ( -
-
+
+
setActiveTab(newActiveTab)} diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/dataset-card.tsx similarity index 93% rename from web/app/(commonLayout)/datasets/DatasetCard.tsx rename to web/app/(commonLayout)/datasets/dataset-card.tsx index 4b40be2c7f..2f0563d47e 100644 --- a/web/app/(commonLayout)/datasets/DatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/dataset-card.tsx @@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { RiMoreFill } from '@remixicon/react' +import { mutate } from 'swr' import cn from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' import { ToastContext } from '@/app/components/base/toast' @@ -57,6 +58,19 @@ const DatasetCard = ({ const onConfirmDelete = useCallback(async () => { try { await deleteDataset(dataset.id) + + // Clear SWR cache to prevent stale data in knowledge retrieval nodes + mutate( + (key) => { + if (typeof key === 'string') return key.includes('/datasets') + if (typeof key === 'object' && key !== null) + return key.url === '/datasets' || key.url?.includes('/datasets') + return false + }, + undefined, + { revalidate: true }, + ) + notify({ type: 'success', message: t('dataset.datasetDeleted') }) if (onSuccess) onSuccess() @@ -162,24 +176,19 @@ const DatasetCard = ({
{dataset.description}
-
+
{ e.stopPropagation() e.preventDefault() }}>
+ containerRef: React.RefObject tags: string[] keywords: string includeAll: boolean diff --git a/web/app/(commonLayout)/datasets/doc.tsx b/web/app/(commonLayout)/datasets/doc.tsx new file mode 100644 index 0000000000..c31dad3c00 --- /dev/null +++ b/web/app/(commonLayout)/datasets/doc.tsx @@ -0,0 +1,203 @@ +'use client' + +import { useEffect, useMemo, useState } from 'react' +import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import { RiCloseLine, RiListUnordered } from '@remixicon/react' +import TemplateEn from './template/template.en.mdx' +import TemplateZh from './template/template.zh.mdx' +import TemplateJa from './template/template.ja.mdx' +import I18n from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import cn from '@/utils/classnames' + +type DocProps = { + apiBaseUrl: string +} + +const Doc = ({ apiBaseUrl }: DocProps) => { + const { locale } = useContext(I18n) + const { t } = useTranslation() + const [toc, setToc] = useState>([]) + const [isTocExpanded, setIsTocExpanded] = useState(false) + const [activeSection, setActiveSection] = useState('') + const { theme } = useTheme() + + // Set initial TOC expanded state based on screen width + useEffect(() => { + const mediaQuery = window.matchMedia('(min-width: 1280px)') + setIsTocExpanded(mediaQuery.matches) + }, []) + + // Extract TOC from article content + useEffect(() => { + const extractTOC = () => { + const article = document.querySelector('article') + if (article) { + const headings = article.querySelectorAll('h2') + const tocItems = Array.from(headings).map((heading) => { + const anchor = heading.querySelector('a') + if (anchor) { + return { + href: anchor.getAttribute('href') || '', + text: anchor.textContent || '', + } + } + return null + }).filter((item): item is { href: string; text: string } => item !== null) + setToc(tocItems) + // Set initial active section + if (tocItems.length > 0) + setActiveSection(tocItems[0].href.replace('#', '')) + } + } + + setTimeout(extractTOC, 0) + }, [locale]) + + // Track scroll position for active section highlighting + useEffect(() => { + const handleScroll = () => { + const scrollContainer = document.querySelector('.scroll-container') + if (!scrollContainer || toc.length === 0) + return + + // Find active section based on scroll position + let currentSection = '' + toc.forEach((item) => { + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const rect = element.getBoundingClientRect() + // Consider section active if its top is above the middle of viewport + if (rect.top <= window.innerHeight / 2) + currentSection = targetId + } + }) + + if (currentSection && currentSection !== activeSection) + setActiveSection(currentSection) + } + + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + scrollContainer.addEventListener('scroll', handleScroll) + handleScroll() // Initial check + return () => scrollContainer.removeEventListener('scroll', handleScroll) + } + }, [toc, activeSection]) + + // Handle TOC item click + const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { + e.preventDefault() + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + const 
headerOffset = -40 + const elementTop = element.offsetTop - headerOffset + scrollContainer.scrollTo({ + top: elementTop, + behavior: 'smooth', + }) + } + } + } + + const Template = useMemo(() => { + switch (locale) { + case LanguagesSupported[1]: + return + case LanguagesSupported[7]: + return + default: + return + } + }, [apiBaseUrl, locale]) + + return ( +
+
+ {isTocExpanded + ? ( + + ) + : ( + + )} +
+
+ {Template} +
+
+ ) +} + +export default Doc diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/new-dataset-card.tsx similarity index 100% rename from web/app/(commonLayout)/datasets/NewDatasetCard.tsx rename to web/app/(commonLayout)/datasets/new-dataset-card.tsx diff --git a/web/app/(commonLayout)/datasets/page.tsx b/web/app/(commonLayout)/datasets/page.tsx index 60a542f0a2..cbfe25ebd2 100644 --- a/web/app/(commonLayout)/datasets/page.tsx +++ b/web/app/(commonLayout)/datasets/page.tsx @@ -1,6 +1,6 @@ 'use client' import { useTranslation } from 'react-i18next' -import Container from './Container' +import Container from './container' import useDocumentTitle from '@/hooks/use-document-title' const AppList = () => { diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index ebb2e6a806..f1bb5d9156 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx index 6c0e20e1bb..3011cecbc1 100644 --- a/web/app/(commonLayout)/datasets/template/template.ja.mdx +++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index c21ce3bf5f..b7ea889a46 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
@@ -1915,7 +1915,7 @@ ___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 8250789def..41a6971bf5 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -87,7 +87,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
{ setIsShowAvatarPicker(true) }} - className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black bg-opacity-50 opacity-0 transition-opacity group-hover:opacity-100" + className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100" > diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index c04d79d2f2..288dcf8c8b 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -12,7 +12,6 @@ import { RiFileUploadLine, } from '@remixicon/react' import AppIcon from '../base/app-icon' -import cn from '@/utils/classnames' import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' @@ -31,6 +30,7 @@ import Divider from '../base/divider' import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' +import cn from '@/utils/classnames' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, @@ -256,32 +256,40 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx }} className='block w-full' > -
-
- -
-
+
+
+
+ +
+ {expand && ( +
+
+ +
+
+ )} +
+ {!expand && ( +
+
-
-
-
-
{appDetail.name}
+ )} + {expand && ( +
+
+
{appDetail.name}
+
+
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
-
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
-
+ )}
)} diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index 72dd33c72e..479eedc9cf 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -32,7 +32,7 @@ const AccessControlDialog = ({ leaveFrom="opacity-100" leaveTo="opacity-0" > -
+
diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index da4a25c1d8..0fad6cc740 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() { setSelectedGroupsForBreadcrumb([]) }, [setSelectedGroupsForBreadcrumb]) return
- 0 && 'text-text-accent cursor-pointer')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} + 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} {selectedGroupsForBreadcrumb.map((group, index) => { return
/ @@ -198,7 +198,7 @@ type BaseItemProps = { children: React.ReactNode } function BaseItem({ children, className }: BaseItemProps) { - return
+ return
{children}
} diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 4c36ad9956..4b1a5620ae 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -4,7 +4,6 @@ import React, { useRef, useState } from 'react' import { useGetState, useInfiniteScroll } from 'ahooks' import { useTranslation } from 'react-i18next' import Link from 'next/link' -import produce from 'immer' import TypeIcon from '../type-icon' import Modal from '@/app/components/base/modal' import type { DataSet } from '@/models/datasets' @@ -29,9 +28,10 @@ const SelectDataSet: FC = ({ onSelect, }) => { const { t } = useTranslation() - const [selected, setSelected] = React.useState(selectedIds.map(id => ({ id }) as any)) + const [selected, setSelected] = React.useState([]) const [loaded, setLoaded] = React.useState(false) const [datasets, setDataSets] = React.useState(null) + const [hasInitialized, setHasInitialized] = React.useState(false) const hasNoData = !datasets || datasets?.length === 0 const canSelectMulti = true @@ -49,19 +49,17 @@ const SelectDataSet: FC = ({ const newList = [...(datasets || []), ...data.filter(item => item.indexing_technique || item.provider === 'external')] setDataSets(newList) setLoaded(true) - if (!selected.find(item => !item.name)) - return { list: [] } - const newSelected = produce(selected, (draft) => { - selected.forEach((item, index) => { - if (!item.name) { // not fetched database - const newItem = newList.find(i => i.id === item.id) - if (newItem) - draft[index] = newItem - } - }) - }) - setSelected(newSelected) + // Initialize selected datasets based on selectedIds and available datasets + if (!hasInitialized) { + if (selectedIds.length > 0) { + const validSelectedDatasets = selectedIds + .map(id => newList.find(item => item.id === id)) + .filter(Boolean) as DataSet[] + setSelected(validSelectedDatasets) + } + setHasInitialized(true) + } } return { list: [] } }, diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index 9835481ae0..9b33dbe30e 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -55,8 +55,6 @@ const SettingsModal: FC = ({ const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding) const { modelList: rerankModelList, - defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { t } = useTranslation() const docLink = useDocLink() diff --git a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx index 346de078b4..85c55c5385 100644 --- a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx +++ b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx @@ -40,13 +40,13 @@ type CategoryItemProps = { } function CategoryItem({ category, active, onClick }: CategoryItemProps) { return
{ onClick?.(category) }}> {category === AppCategories.RECOMMENDED &&
    } + className={classNames('system-sm-medium text-components-menu-item-text group-hover:text-components-menu-item-text-hover group-[.active]:text-components-menu-item-text-active', active && 'system-sm-semibold')} />
  • } diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index c37f7b051a..70a45a4bbe 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -82,8 +82,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') getRedirection(isCurrentWorkspaceEditor, app, push) } - catch { - notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + catch (e: any) { + notify({ + type: 'error', + message: e.message || t('app.newApp.appCreateFailed'), + }) } isCreatingRef.current = false }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor]) diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 6ad4116dd6..3ab54733dc 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -106,8 +106,8 @@ const Uploader: FC = ({
    - {t('datasetCreation.stepOne.uploader.button')} - {t('datasetDocuments.list.batchModal.browse')} + {t('app.dslUploader.button')} + {t('app.dslUploader.browse')}
    {dragging &&
    } diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 603b5922c5..688da4c25d 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -117,8 +117,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { if (onRefresh) onRefresh() } - catch { - notify({ type: 'error', message: t('app.editFailed') }) + catch (e: any) { + notify({ + type: 'error', + message: e.message || t('app.editFailed'), + }) } }, [app.id, notify, onRefresh, t]) @@ -364,26 +367,20 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
    {app.description}
    -
    +
    {isCurrentWorkspaceEditor && ( <>
    { e.stopPropagation() e.preventDefault() }}> -
    +
    { />
    -
    +
    } diff --git a/web/app/components/apps/footer.tsx b/web/app/components/apps/footer.tsx index c5efb2b8b4..9fed4c8757 100644 --- a/web/app/components/apps/footer.tsx +++ b/web/app/components/apps/footer.tsx @@ -1,6 +1,6 @@ -import React, { useState } from 'react' +import React from 'react' import Link from 'next/link' -import { RiCloseLine, RiDiscordFill, RiGithubFill } from '@remixicon/react' +import { RiDiscordFill, RiGithubFill } from '@remixicon/react' import { useTranslation } from 'react-i18next' type CustomLinkProps = { @@ -26,24 +26,9 @@ const CustomLink = React.memo(({ const Footer = () => { const { t } = useTranslation() - const [isVisible, setIsVisible] = useState(true) - - const handleClose = () => { - setIsVisible(false) - } - - if (!isVisible) - return null return (