diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000..3e95f2e982
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,1197 @@
+# ------------------------------
+# Environment Variables for API service & worker
+# ------------------------------
+
+# ------------------------------
+# Common Variables
+# ------------------------------
+
+# The backend URL of the console API,
+# used to construct the authorization callback URL.
+# If empty, it is the same domain.
+# Example: https://api.console.dify.ai
+CONSOLE_API_URL=
+
+# The front-end URL of the console web,
+# used to construct some front-end addresses and for CORS configuration.
+# If empty, it is the same domain.
+# Example: https://console.dify.ai
+CONSOLE_WEB_URL=
+
+# Service API URL,
+# used to display the Service API Base URL to the front-end.
+# If empty, it is the same domain.
+# Example: https://api.dify.ai
+SERVICE_API_URL=
+
+# WebApp API backend URL,
+# used to declare the back-end URL for the front-end API.
+# If empty, it is the same domain.
+# Example: https://api.app.dify.ai
+APP_API_URL=
+
+# WebApp URL,
+# used to display the WebApp API Base URL to the front-end.
+# If empty, it is the same domain.
+# Example: https://app.dify.ai
+APP_WEB_URL=
+
+# File preview or download URL prefix,
+# used to display file preview or download URLs to the front-end or as multi-modal inputs;
+# the URL is signed and has an expiration time.
+# Setting FILES_URL is required for file processing plugins.
+# - For https://example.com, use FILES_URL=https://example.com
+# - For http://example.com, use FILES_URL=http://example.com
+# Recommendation: use a dedicated domain (e.g., https://upload.example.com).
+# Alternatively, use http://<your-ip>:5001 or http://api:5001,
+# ensuring port 5001 is externally accessible (see docker-compose.yaml).
+FILES_URL=
+
+# INTERNAL_FILES_URL is used for plugin daemon communication within the Docker network.
+# Set this to the internal Docker service URL for proper plugin file access.
+# Example: INTERNAL_FILES_URL=http://api:5001
+INTERNAL_FILES_URL=
+
+# ------------------------------
+# Server Configuration
+# ------------------------------
+
+# The log level for the application.
+# Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
+LOG_LEVEL=INFO
+# Log file path
+LOG_FILE=/app/logs/server.log
+# Log file max size, the unit is MB
+LOG_FILE_MAX_SIZE=20
+# Log file max backup count
+LOG_FILE_BACKUP_COUNT=5
+# Log date format
+LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
+# Log Timezone
+LOG_TZ=UTC
+
+# Debug mode, default is false.
+# It is recommended to turn on this configuration for local development
+# to prevent some problems caused by monkey patching.
+DEBUG=false
+
+# Flask debug mode. When turned on, it can output trace information at the interface,
+# which is convenient for debugging.
+FLASK_DEBUG=false
+
+# Enable request logging, which will log request and response information.
+# These logs are written at the DEBUG level.
+ENABLE_REQUEST_LOGGING=False
+
+# A secret key that is used for securely signing the session cookie
+# and encrypting sensitive information in the database.
+# You can generate a strong key using `openssl rand -base64 42`.
+SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U
+
+# Password for admin user initialization.
+# If left unset, the admin user will not be prompted for this password
+# when creating the initial admin account.
+# The length of the password cannot exceed 30 characters.
+INIT_PASSWORD=
+
+# Deployment environment.
+# Supported values are `PRODUCTION`, `TESTING`.
Default is `PRODUCTION`. +# Testing environment. There will be a distinct color label on the front-end page, +# indicating that this environment is a testing environment. +DEPLOY_ENV=PRODUCTION + +# Whether to enable the version check policy. +# If set to empty, https://updates.dify.ai will be called for version check. +CHECK_UPDATE_URL=https://updates.dify.ai + +# Used to change the OpenAI base address, default is https://api.openai.com/v1. +# When OpenAI cannot be accessed in China, replace it with a domestic mirror address, +# or when a local model provides OpenAI compatible API, it can be replaced. +OPENAI_API_BASE=https://api.openai.com/v1 + +# When enabled, migrations will be executed prior to application startup +# and the application will start after the migrations have completed. +MIGRATION_ENABLED=true + +# File Access Time specifies a time interval in seconds for the file to be accessed. +# The default value is 300 seconds. +FILES_ACCESS_TIMEOUT=300 + +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 + +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 + +# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. +APP_MAX_ACTIVE_REQUESTS=0 +APP_MAX_EXECUTION_TIME=1200 + +# ------------------------------ +# Container Startup Related Configuration +# Only effective when starting with docker image or docker-compose. +# ------------------------------ + +# API service binding address, default: 0.0.0.0, i.e., all addresses can be accessed. +DIFY_BIND_ADDRESS=0.0.0.0 + +# API service binding port number, default 5001. +DIFY_PORT=5001 + +# The number of API server workers, i.e., the number of workers. +# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent +# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers +SERVER_WORKER_AMOUNT=1 + +# Defaults to gevent. If using windows, it can be switched to sync or solo. +SERVER_WORKER_CLASS=gevent + +# Default number of worker connections, the default is 10. +SERVER_WORKER_CONNECTIONS=10 + +# Similar to SERVER_WORKER_CLASS. +# If using windows, it can be switched to sync or solo. +CELERY_WORKER_CLASS= + +# Request handling timeout. The default is 200, +# it is recommended to set it to 360 to support a longer sse connection time. +GUNICORN_TIMEOUT=360 + +# The number of Celery workers. The default is 1, and can be set as needed. +CELERY_WORKER_AMOUNT= + +# Flag indicating whether to enable autoscaling of Celery workers. +# +# Autoscaling is useful when tasks are CPU intensive and can be dynamically +# allocated and deallocated based on the workload. +# +# When autoscaling is enabled, the maximum and minimum number of workers can +# be specified. The autoscaling algorithm will dynamically adjust the number +# of workers within the specified range. +# +# Default is false (i.e., autoscaling is disabled). +# +# Example: +# CELERY_AUTO_SCALE=true +CELERY_AUTO_SCALE=false + +# The maximum number of Celery workers that can be autoscaled. +# This is optional and only used when autoscaling is enabled. +# Default is not set. +CELERY_MAX_WORKERS= + +# The minimum number of Celery workers that can be autoscaled. +# This is optional and only used when autoscaling is enabled. +# Default is not set. 
+CELERY_MIN_WORKERS=
+
+# API Tool configuration
+API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
+API_TOOL_DEFAULT_READ_TIMEOUT=60
+
+# ------------------------------
+# Datasource Configuration
+# ------------------------------
+ENABLE_WEBSITE_JINAREADER=true
+ENABLE_WEBSITE_FIRECRAWL=true
+ENABLE_WEBSITE_WATERCRAWL=true
+
+# ------------------------------
+# Database Configuration
+# The database uses PostgreSQL. Please use the public schema.
+# It is consistent with the configuration in the 'db' service below.
+# ------------------------------
+
+DB_USERNAME=postgres
+DB_PASSWORD=difyai123456
+DB_HOST=db
+DB_PORT=5432
+DB_DATABASE=dify
+# The size of the database connection pool.
+# The default is 30 connections, which can be appropriately increased.
+SQLALCHEMY_POOL_SIZE=30
+# Database connection pool recycling time, the default is 3600 seconds.
+SQLALCHEMY_POOL_RECYCLE=3600
+# Whether to print SQL, default is false.
+SQLALCHEMY_ECHO=false
+# If true, connections will be tested for liveness upon each checkout.
+SQLALCHEMY_POOL_PRE_PING=false
+# Whether to enable the last-in-first-out (LIFO) option; if false, the default FIFO queue is used.
+SQLALCHEMY_POOL_USE_LIFO=false
+
+# Maximum number of connections to the database
+# Default is 100
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-MAX-CONNECTIONS
+POSTGRES_MAX_CONNECTIONS=100
+
+# Sets the amount of shared memory used for postgres's shared buffers.
+# Default is 128MB
+# Recommended value: 25% of available memory
+# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-SHARED-BUFFERS
+POSTGRES_SHARED_BUFFERS=128MB
+
+# Sets the amount of memory used by each database worker for working space.
+# Default is 4MB
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-WORK-MEM
+POSTGRES_WORK_MEM=4MB
+
+# Sets the amount of memory reserved for maintenance activities.
+# Default is 64MB
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-MAINTENANCE-WORK-MEM
+POSTGRES_MAINTENANCE_WORK_MEM=64MB
+
+# Sets the planner's assumption about the effective cache size.
+# Default is 4096MB
+#
+# Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE
+POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
+
+# ------------------------------
+# Redis Configuration
+# This Redis configuration is used for caching and for pub/sub during conversation.
+# ------------------------------
+
+REDIS_HOST=redis
+REDIS_PORT=6379
+REDIS_USERNAME=
+REDIS_PASSWORD=difyai123456
+REDIS_USE_SSL=false
+REDIS_DB=0
+
+# Whether to use Redis Sentinel mode.
+# If set to true, the application will automatically discover and connect to the master node through Sentinel.
+REDIS_USE_SENTINEL=false
+
+# List of Redis Sentinel nodes. If Sentinel mode is enabled, provide at least one Sentinel IP and port.
+# Format: `<sentinel1_ip>:<sentinel1_port>,<sentinel2_ip>:<sentinel2_port>,<sentinel3_ip>:<sentinel3_port>`
+REDIS_SENTINELS=
+REDIS_SENTINEL_SERVICE_NAME=
+REDIS_SENTINEL_USERNAME=
+REDIS_SENTINEL_PASSWORD=
+REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
+
+# List of Redis Cluster nodes. If Cluster mode is enabled, provide at least one Cluster IP and port.
+# Format: `<cluster1_ip>:<cluster1_port>,<cluster2_ip>:<cluster2_port>,<cluster3_ip>:<cluster3_port>`
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=
+
+# ------------------------------
+# Celery Configuration
+# ------------------------------
+
+# Use Redis as the broker, and Redis db 1 for the Celery broker.
+# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`
+# Example: redis://:difyai123456@redis:6379/1
+# If using Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
+# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
+CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
+BROKER_USE_SSL=false
+
+# If you are using Redis Sentinel for high availability, configure the following settings.
+CELERY_USE_SENTINEL=false
+CELERY_SENTINEL_MASTER_NAME=
+CELERY_SENTINEL_PASSWORD=
+CELERY_SENTINEL_SOCKET_TIMEOUT=0.1
+
+# ------------------------------
+# CORS Configuration
+# Used to set the front-end cross-domain access policy.
+# ------------------------------
+
+# Specifies the allowed origins for cross-origin requests to the Web API,
+# e.g. https://dify.app or * for all origins.
+WEB_API_CORS_ALLOW_ORIGINS=*
+
+# Specifies the allowed origins for cross-origin requests to the console API,
+# e.g. https://cloud.dify.ai or * for all origins.
+CONSOLE_CORS_ALLOW_ORIGINS=*
+
+# ------------------------------
+# File Storage Configuration
+# ------------------------------
+
+# The type of storage to use for storing user files.
+STORAGE_TYPE=opendal
+
+# Apache OpenDAL Configuration
+# The configuration for OpenDAL uses the following format: OPENDAL_<SCHEME_NAME>_<CONFIG_NAME>.
+# You can find all the service configurations (CONFIG_NAME) in the repository at: https://github.com/apache/opendal/tree/main/core/src/services.
+# Dify will scan configurations starting with OPENDAL_ and automatically apply them.
+# The scheme name for the OpenDAL storage.
+OPENDAL_SCHEME=fs
+# Configurations for OpenDAL Local File System.
+OPENDAL_FS_ROOT=storage
+
+# ClickZetta Volume Configuration (for storage backend)
+# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume
+# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters
+
+# Volume type selection (three types available):
+# - user: Personal/small team use, simple config, user-level permissions
+# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions
+# - external: Data lake integration, external storage connection, volume-level + storage-level permissions
+CLICKZETTA_VOLUME_TYPE=user
+
+# External Volume name (required only when TYPE=external)
+CLICKZETTA_VOLUME_NAME=
+
+# Table Volume table prefix (used only when TYPE=table)
+CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_
+
+# Dify file directory prefix (isolates from other apps, recommended to keep default)
+CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km
+
+# S3 Configuration
+#
+S3_ENDPOINT=
+S3_REGION=us-east-1
+S3_BUCKET_NAME=difyai
+S3_ACCESS_KEY=
+S3_SECRET_KEY=
+# Whether to use AWS managed IAM roles for authenticating with the S3 service.
+# If set to false, the access key and secret key must be provided.
+S3_USE_AWS_MANAGED_IAM=false
+
+# Azure Blob Configuration
+#
+AZURE_BLOB_ACCOUNT_NAME=difyai
+AZURE_BLOB_ACCOUNT_KEY=difyai
+AZURE_BLOB_CONTAINER_NAME=difyai-container
+AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
+
+# Google Storage Configuration
+#
+GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
+GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=
+
+# The Alibaba Cloud OSS configurations.
+#
+ALIYUN_OSS_BUCKET_NAME=your-bucket-name
+ALIYUN_OSS_ACCESS_KEY=your-access-key
+ALIYUN_OSS_SECRET_KEY=your-secret-key
+ALIYUN_OSS_ENDPOINT=https://oss-ap-southeast-1-internal.aliyuncs.com
+ALIYUN_OSS_REGION=ap-southeast-1
+ALIYUN_OSS_AUTH_VERSION=v4
+# Don't start with '/'. OSS doesn't support a leading slash in object names.
+ALIYUN_OSS_PATH=your-path + +# Tencent COS Configuration +# +TENCENT_COS_BUCKET_NAME=your-bucket-name +TENCENT_COS_SECRET_KEY=your-secret-key +TENCENT_COS_SECRET_ID=your-secret-id +TENCENT_COS_REGION=your-region +TENCENT_COS_SCHEME=your-scheme + +# Oracle Storage Configuration +# +OCI_ENDPOINT=https://your-object-storage-namespace.compat.objectstorage.us-ashburn-1.oraclecloud.com +OCI_BUCKET_NAME=your-bucket-name +OCI_ACCESS_KEY=your-access-key +OCI_SECRET_KEY=your-secret-key +OCI_REGION=us-ashburn-1 + +# Huawei OBS Configuration +# +HUAWEI_OBS_BUCKET_NAME=your-bucket-name +HUAWEI_OBS_SECRET_KEY=your-secret-key +HUAWEI_OBS_ACCESS_KEY=your-access-key +HUAWEI_OBS_SERVER=your-server-url + +# Volcengine TOS Configuration +# +VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name +VOLCENGINE_TOS_SECRET_KEY=your-secret-key +VOLCENGINE_TOS_ACCESS_KEY=your-access-key +VOLCENGINE_TOS_ENDPOINT=your-server-url +VOLCENGINE_TOS_REGION=your-region + +# Baidu OBS Storage Configuration +# +BAIDU_OBS_BUCKET_NAME=your-bucket-name +BAIDU_OBS_SECRET_KEY=your-secret-key +BAIDU_OBS_ACCESS_KEY=your-access-key +BAIDU_OBS_ENDPOINT=your-server-url + +# Supabase Storage Configuration +# +SUPABASE_BUCKET_NAME=your-bucket-name +SUPABASE_API_KEY=your-access-key +SUPABASE_URL=your-server-url + +# ------------------------------ +# Vector Database Configuration +# ------------------------------ + +# The type of vector store to use. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. +VECTOR_STORE=weaviate + +# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. +WEAVIATE_ENDPOINT=http://weaviate:8080 +WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih + +# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`. +QDRANT_URL=http://qdrant:6333 +QDRANT_API_KEY=difyai123456 +QDRANT_CLIENT_TIMEOUT=20 +QDRANT_GRPC_ENABLED=false +QDRANT_GRPC_PORT=6334 +QDRANT_REPLICATION_FACTOR=1 + +# Milvus configuration. Only available when VECTOR_STORE is `milvus`. +# The milvus uri. 
+MILVUS_URI=http://host.docker.internal:19530 +MILVUS_DATABASE= +MILVUS_TOKEN= +MILVUS_USER= +MILVUS_PASSWORD= +MILVUS_ENABLE_HYBRID_SEARCH=False +MILVUS_ANALYZER_PARAMS= + +# MyScale configuration, only available when VECTOR_STORE is `myscale` +# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: +# https://myscale.com/docs/en/text-search/#understanding-fts-index-parameters +MYSCALE_HOST=myscale +MYSCALE_PORT=8123 +MYSCALE_USER=default +MYSCALE_PASSWORD= +MYSCALE_DATABASE=dify +MYSCALE_FTS_PARAMS= + +# Couchbase configurations, only available when VECTOR_STORE is `couchbase` +# The connection string must include hostname defined in the docker-compose file (couchbase-server in this case) +COUCHBASE_CONNECTION_STRING=couchbase://couchbase-server +COUCHBASE_USER=Administrator +COUCHBASE_PASSWORD=password +COUCHBASE_BUCKET_NAME=Embeddings +COUCHBASE_SCOPE_NAME=_default + +# pgvector configurations, only available when VECTOR_STORE is `pgvector` +PGVECTOR_HOST=pgvector +PGVECTOR_PORT=5432 +PGVECTOR_USER=postgres +PGVECTOR_PASSWORD=difyai123456 +PGVECTOR_DATABASE=dify +PGVECTOR_MIN_CONNECTION=1 +PGVECTOR_MAX_CONNECTION=5 +PGVECTOR_PG_BIGM=false +PGVECTOR_PG_BIGM_VERSION=1.2-20240606 + +# vastbase configurations, only available when VECTOR_STORE is `vastbase` +VASTBASE_HOST=vastbase +VASTBASE_PORT=5432 +VASTBASE_USER=dify +VASTBASE_PASSWORD=Difyai123456 +VASTBASE_DATABASE=dify +VASTBASE_MIN_CONNECTION=1 +VASTBASE_MAX_CONNECTION=5 + +# pgvecto-rs configurations, only available when VECTOR_STORE is `pgvecto-rs` +PGVECTO_RS_HOST=pgvecto-rs +PGVECTO_RS_PORT=5432 +PGVECTO_RS_USER=postgres +PGVECTO_RS_PASSWORD=difyai123456 +PGVECTO_RS_DATABASE=dify + +# analyticdb configurations, only available when VECTOR_STORE is `analyticdb` +ANALYTICDB_KEY_ID=your-ak +ANALYTICDB_KEY_SECRET=your-sk +ANALYTICDB_REGION_ID=cn-hangzhou +ANALYTICDB_INSTANCE_ID=gp-ab123456 +ANALYTICDB_ACCOUNT=testaccount +ANALYTICDB_PASSWORD=testpassword +ANALYTICDB_NAMESPACE=dify +ANALYTICDB_NAMESPACE_PASSWORD=difypassword +ANALYTICDB_HOST=gp-test.aliyuncs.com +ANALYTICDB_PORT=5432 +ANALYTICDB_MIN_CONNECTION=1 +ANALYTICDB_MAX_CONNECTION=5 + +# TiDB vector configurations, only available when VECTOR_STORE is `tidb_vector` +TIDB_VECTOR_HOST=tidb +TIDB_VECTOR_PORT=4000 +TIDB_VECTOR_USER= +TIDB_VECTOR_PASSWORD= +TIDB_VECTOR_DATABASE=dify + +# Matrixone vector configurations. 
+MATRIXONE_HOST=matrixone +MATRIXONE_PORT=6001 +MATRIXONE_USER=dump +MATRIXONE_PASSWORD=111 +MATRIXONE_DATABASE=dify + +# Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant` +TIDB_ON_QDRANT_URL=http://127.0.0.1 +TIDB_ON_QDRANT_API_KEY=dify +TIDB_ON_QDRANT_CLIENT_TIMEOUT=20 +TIDB_ON_QDRANT_GRPC_ENABLED=false +TIDB_ON_QDRANT_GRPC_PORT=6334 +TIDB_PUBLIC_KEY=dify +TIDB_PRIVATE_KEY=dify +TIDB_API_URL=http://127.0.0.1 +TIDB_IAM_API_URL=http://127.0.0.1 +TIDB_REGION=regions/aws-us-east-1 +TIDB_PROJECT_ID=dify +TIDB_SPEND_LIMIT=100 + +# Chroma configuration, only available when VECTOR_STORE is `chroma` +CHROMA_HOST=127.0.0.1 +CHROMA_PORT=8000 +CHROMA_TENANT=default_tenant +CHROMA_DATABASE=default_database +CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthClientProvider +CHROMA_AUTH_CREDENTIALS= + +# Oracle configuration, only available when VECTOR_STORE is `oracle` +ORACLE_USER=dify +ORACLE_PASSWORD=dify +ORACLE_DSN=oracle:1521/FREEPDB1 +ORACLE_CONFIG_DIR=/app/api/storage/wallet +ORACLE_WALLET_LOCATION=/app/api/storage/wallet +ORACLE_WALLET_PASSWORD=dify +ORACLE_IS_AUTONOMOUS=false + +# relyt configurations, only available when VECTOR_STORE is `relyt` +RELYT_HOST=db +RELYT_PORT=5432 +RELYT_USER=postgres +RELYT_PASSWORD=difyai123456 +RELYT_DATABASE=postgres + +# open search configuration, only available when VECTOR_STORE is `opensearch` +OPENSEARCH_HOST=opensearch +OPENSEARCH_PORT=9200 +OPENSEARCH_SECURE=true +OPENSEARCH_VERIFY_CERTS=true +OPENSEARCH_AUTH_METHOD=basic +OPENSEARCH_USER=admin +OPENSEARCH_PASSWORD=admin +# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless +OPENSEARCH_AWS_REGION=ap-southeast-1 +OPENSEARCH_AWS_SERVICE=aoss + +# tencent vector configurations, only available when VECTOR_STORE is `tencent` +TENCENT_VECTOR_DB_URL=http://127.0.0.1 +TENCENT_VECTOR_DB_API_KEY=dify +TENCENT_VECTOR_DB_TIMEOUT=30 +TENCENT_VECTOR_DB_USERNAME=dify +TENCENT_VECTOR_DB_DATABASE=dify +TENCENT_VECTOR_DB_SHARD=1 +TENCENT_VECTOR_DB_REPLICAS=2 +TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false + +# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` +ELASTICSEARCH_HOST=0.0.0.0 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic +KIBANA_PORT=5601 + +# baidu vector configurations, only available when VECTOR_STORE is `baidu` +BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 +BAIDU_VECTOR_DB_ACCOUNT=root +BAIDU_VECTOR_DB_API_KEY=dify +BAIDU_VECTOR_DB_DATABASE=dify +BAIDU_VECTOR_DB_SHARD=1 +BAIDU_VECTOR_DB_REPLICAS=3 + +# VikingDB configurations, only available when VECTOR_STORE is `vikingdb` +VIKINGDB_ACCESS_KEY=your-ak +VIKINGDB_SECRET_KEY=your-sk +VIKINGDB_REGION=cn-shanghai +VIKINGDB_HOST=api-vikingdb.xxx.volces.com +VIKINGDB_SCHEMA=http +VIKINGDB_CONNECTION_TIMEOUT=30 +VIKINGDB_SOCKET_TIMEOUT=30 + +# Lindorm configuration, only available when VECTOR_STORE is `lindorm` +LINDORM_URL=http://lindorm:30070 +LINDORM_USERNAME=lindorm +LINDORM_PASSWORD=lindorm +LINDORM_QUERY_TIMEOUT=1 + +# OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` +OCEANBASE_VECTOR_HOST=oceanbase +OCEANBASE_VECTOR_PORT=2881 +OCEANBASE_VECTOR_USER=root@test +OCEANBASE_VECTOR_PASSWORD=difyai123456 +OCEANBASE_VECTOR_DATABASE=test +OCEANBASE_CLUSTER_NAME=difyai +OCEANBASE_MEMORY_LIMIT=6G +OCEANBASE_ENABLE_HYBRID_SEARCH=false + +# opengauss configurations, only available when VECTOR_STORE is `opengauss` +OPENGAUSS_HOST=opengauss +OPENGAUSS_PORT=6600 
+OPENGAUSS_USER=postgres +OPENGAUSS_PASSWORD=Dify@123 +OPENGAUSS_DATABASE=dify +OPENGAUSS_MIN_CONNECTION=1 +OPENGAUSS_MAX_CONNECTION=5 +OPENGAUSS_ENABLE_PQ=false + +# huawei cloud search service vector configurations, only available when VECTOR_STORE is `huawei_cloud` +HUAWEI_CLOUD_HOSTS=https://127.0.0.1:9200 +HUAWEI_CLOUD_USER=admin +HUAWEI_CLOUD_PASSWORD=admin + +# Upstash Vector configuration, only available when VECTOR_STORE is `upstash` +UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io +UPSTASH_VECTOR_TOKEN=dify + +# TableStore Vector configuration +# (only used when VECTOR_STORE is tablestore) +TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com +TABLESTORE_INSTANCE_NAME=instance-name +TABLESTORE_ACCESS_KEY_ID=xxx +TABLESTORE_ACCESS_KEY_SECRET=xxx + +# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta` +CLICKZETTA_USERNAME= +CLICKZETTA_PASSWORD= +CLICKZETTA_INSTANCE= +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=quick_start +CLICKZETTA_VCLUSTER=default_ap +CLICKZETTA_SCHEMA=dify +CLICKZETTA_BATCH_SIZE=100 +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese +CLICKZETTA_ANALYZER_MODE=smart +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance + +# ------------------------------ +# Knowledge Configuration +# ------------------------------ + +# Upload file size limit, default 15M. +UPLOAD_FILE_SIZE_LIMIT=15 + +# The maximum number of files that can be uploaded at a time, default 5. +UPLOAD_FILE_BATCH_LIMIT=5 + +# ETL type, support: `dify`, `Unstructured` +# `dify` Dify's proprietary file extraction scheme +# `Unstructured` Unstructured.io file extraction scheme +ETL_TYPE=dify + +# Unstructured API path and API key, needs to be configured when ETL_TYPE is Unstructured +# Or using Unstructured for document extractor node for pptx. +# For example: http://unstructured:8000/general/v0/general +UNSTRUCTURED_API_URL= +UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=true + +# ------------------------------ +# Model Configuration +# ------------------------------ + +# The maximum number of tokens allowed for prompt generation. +# This setting controls the upper limit of tokens that can be used by the LLM +# when generating a prompt in the prompt generation tool. +# Default: 512 tokens. +PROMPT_GENERATION_MAX_TOKENS=512 + +# The maximum number of tokens allowed for code generation. +# This setting controls the upper limit of tokens that can be used by the LLM +# when generating code in the code generation tool. +# Default: 1024 tokens. +CODE_GENERATION_MAX_TOKENS=1024 + +# Enable or disable plugin based token counting. If disabled, token counting will return 0. +# This can improve performance by skipping token counting operations. +# Default: false (disabled). +PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false + +# ------------------------------ +# Multi-modal Configuration +# ------------------------------ + +# The format of the image/video/audio/document sent when the multi-modal model is input, +# the default is base64, optional url. +# The delay of the call in url mode will be lower than that in base64 mode. +# It is generally recommended to use the more compatible base64 mode. +# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document. +MULTIMODAL_SEND_FORMAT=base64 +# Upload image file size limit, default 10M. +UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +# Upload video file size limit, default 100M. 
+UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
+# Upload audio file size limit, default 50M.
+UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
+
+# ------------------------------
+# Sentry Configuration
+# Used for application monitoring and error log tracking.
+# ------------------------------
+SENTRY_DSN=
+
+# API service Sentry DSN address. Default is empty; when empty,
+# no monitoring information is reported to Sentry.
+# If not set, Sentry error reporting will be disabled.
+API_SENTRY_DSN=
+# The sample rate for Sentry events reported by the API service; 0.01 means 1%.
+API_SENTRY_TRACES_SAMPLE_RATE=1.0
+# The sample rate for Sentry profiles reported by the API service; 0.01 means 1%.
+API_SENTRY_PROFILES_SAMPLE_RATE=1.0
+
+# Web service Sentry DSN address. Default is empty; when empty,
+# no monitoring information is reported to Sentry.
+# If not set, Sentry error reporting will be disabled.
+WEB_SENTRY_DSN=
+
+# ------------------------------
+# Notion Integration Configuration
+# Variables can be obtained by applying for Notion integration: https://www.notion.so/my-integrations
+# ------------------------------
+
+# Configure as "public" or "internal".
+# Since Notion's OAuth redirect URL only supports HTTPS,
+# if deploying locally, please use Notion's internal integration.
+NOTION_INTEGRATION_TYPE=public
+# Notion OAuth client secret (used for public integration type)
+NOTION_CLIENT_SECRET=
+# Notion OAuth client ID (used for public integration type)
+NOTION_CLIENT_ID=
+# Notion internal integration secret.
+# If the value of NOTION_INTEGRATION_TYPE is "internal",
+# you need to configure this variable.
+NOTION_INTERNAL_SECRET=
+
+# ------------------------------
+# Mail Related Configuration
+# ------------------------------
+
+# Mail type. Supported values: resend, smtp, sendgrid
+MAIL_TYPE=resend
+
+# Default send-from email address, used if not otherwise specified.
+# If using SendGrid, use the 'from' field for authentication if necessary.
+MAIL_DEFAULT_SEND_FROM=
+
+# API key for the Resend email provider, used when MAIL_TYPE is `resend`.
+RESEND_API_URL=https://api.resend.com
+RESEND_API_KEY=your-resend-api-key
+
+
+# SMTP server configuration, used when MAIL_TYPE is `smtp`
+SMTP_SERVER=
+SMTP_PORT=465
+SMTP_USERNAME=
+SMTP_PASSWORD=
+SMTP_USE_TLS=true
+SMTP_OPPORTUNISTIC_TLS=false
+
+# SendGrid configuration, used when MAIL_TYPE is `sendgrid`
+SENDGRID_API_KEY=
+
+# ------------------------------
+# Others Configuration
+# ------------------------------
+
+# Maximum length of segmentation tokens for indexing
+INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
+
+# Member invitation link valid time (hours).
+# Default: 72.
+INVITE_EXPIRY_HOURS=72
+
+# Reset password token valid time (minutes).
+RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+
+# The sandbox service endpoint.
+CODE_EXECUTION_ENDPOINT=http://sandbox:8194 +CODE_EXECUTION_API_KEY=dify-sandbox +CODE_MAX_NUMBER=9223372036854775807 +CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_DEPTH=5 +CODE_MAX_PRECISION=20 +CODE_MAX_STRING_LENGTH=80000 +CODE_MAX_STRING_ARRAY_LENGTH=30 +CODE_MAX_OBJECT_ARRAY_LENGTH=30 +CODE_MAX_NUMBER_ARRAY_LENGTH=1000 +CODE_EXECUTION_CONNECT_TIMEOUT=10 +CODE_EXECUTION_READ_TIMEOUT=60 +CODE_EXECUTION_WRITE_TIMEOUT=10 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 + +# Workflow runtime configuration +WORKFLOW_MAX_EXECUTION_STEPS=500 +WORKFLOW_MAX_EXECUTION_TIME=1200 +WORKFLOW_CALL_MAX_DEPTH=5 +MAX_VARIABLE_SIZE=204800 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 +WORKFLOW_FILE_UPLOAD_LIMIT=10 + +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + +# HTTP request node in workflow configuration +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 +HTTP_REQUEST_NODE_SSL_VERIFY=True + +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false + +# SSRF Proxy server HTTP URL +SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 +# SSRF Proxy server HTTPS URL +SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 + +# Maximum loop count in the workflow +LOOP_NODE_MAX_COUNT=100 + +# The maximum number of tools that can be used in the agent. +MAX_TOOLS_NUM=10 + +# Maximum number of Parallelism branches in the workflow +MAX_PARALLEL_LIMIT=10 + +# The maximum number of iterations for agent setting +MAX_ITERATIONS_NUM=99 + +# ------------------------------ +# Environment Variables for web Service +# ------------------------------ + +# The timeout for the text generation in millisecond +TEXT_GENERATION_TIMEOUT_MS=60000 + +# Allow rendering unsafe URLs which have "data:" scheme. +ALLOW_UNSAFE_DATA_SCHEME=false + +# ------------------------------ +# Environment Variables for db Service +# ------------------------------ + +# The name of the default postgres user. +POSTGRES_USER=${DB_USERNAME} +# The password for the default postgres user. +POSTGRES_PASSWORD=${DB_PASSWORD} +# The name of the default postgres database. 
+POSTGRES_DB=${DB_DATABASE} +# postgres data directory +PGDATA=/var/lib/postgresql/data/pgdata + +# ------------------------------ +# Environment Variables for sandbox Service +# ------------------------------ + +# The API key for the sandbox service +SANDBOX_API_KEY=dify-sandbox +# The mode in which the Gin framework runs +SANDBOX_GIN_MODE=release +# The timeout for the worker in seconds +SANDBOX_WORKER_TIMEOUT=15 +# Enable network for the sandbox service +SANDBOX_ENABLE_NETWORK=true +# HTTP proxy URL for SSRF protection +SANDBOX_HTTP_PROXY=http://ssrf_proxy:3128 +# HTTPS proxy URL for SSRF protection +SANDBOX_HTTPS_PROXY=http://ssrf_proxy:3128 +# The port on which the sandbox service runs +SANDBOX_PORT=8194 + +# ------------------------------ +# Environment Variables for weaviate Service +# (only used when VECTOR_STORE is weaviate) +# ------------------------------ +WEAVIATE_PERSISTENCE_DATA_PATH=/var/lib/weaviate +WEAVIATE_QUERY_DEFAULTS_LIMIT=25 +WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true +WEAVIATE_DEFAULT_VECTORIZER_MODULE=none +WEAVIATE_CLUSTER_HOSTNAME=node1 +WEAVIATE_AUTHENTICATION_APIKEY_ENABLED=true +WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih +WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai +WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true +WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai + +# ------------------------------ +# Environment Variables for Chroma +# (only used when VECTOR_STORE is chroma) +# ------------------------------ + +# Authentication credentials for Chroma server +CHROMA_SERVER_AUTHN_CREDENTIALS=difyai123456 +# Authentication provider for Chroma server +CHROMA_SERVER_AUTHN_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider +# Persistence setting for Chroma server +CHROMA_IS_PERSISTENT=TRUE + +# ------------------------------ +# Environment Variables for Oracle Service +# (only used when VECTOR_STORE is oracle) +# ------------------------------ +ORACLE_PWD=Dify123456 +ORACLE_CHARACTERSET=AL32UTF8 + +# ------------------------------ +# Environment Variables for milvus Service +# (only used when VECTOR_STORE is milvus) +# ------------------------------ +# ETCD configuration for auto compaction mode +ETCD_AUTO_COMPACTION_MODE=revision +# ETCD configuration for auto compaction retention in terms of number of revisions +ETCD_AUTO_COMPACTION_RETENTION=1000 +# ETCD configuration for backend quota in bytes +ETCD_QUOTA_BACKEND_BYTES=4294967296 +# ETCD configuration for the number of changes before triggering a snapshot +ETCD_SNAPSHOT_COUNT=50000 +# MinIO access key for authentication +MINIO_ACCESS_KEY=minioadmin +# MinIO secret key for authentication +MINIO_SECRET_KEY=minioadmin +# ETCD service endpoints +ETCD_ENDPOINTS=etcd:2379 +# MinIO service address +MINIO_ADDRESS=minio:9000 +# Enable or disable security authorization +MILVUS_AUTHORIZATION_ENABLED=true + +# ------------------------------ +# Environment Variables for pgvector / pgvector-rs Service +# (only used when VECTOR_STORE is pgvector / pgvector-rs) +# ------------------------------ +PGVECTOR_PGUSER=postgres +# The password for the default postgres user. +PGVECTOR_POSTGRES_PASSWORD=difyai123456 +# The name of the default postgres database. 
+PGVECTOR_POSTGRES_DB=dify +# postgres data directory +PGVECTOR_PGDATA=/var/lib/postgresql/data/pgdata + +# ------------------------------ +# Environment Variables for opensearch +# (only used when VECTOR_STORE is opensearch) +# ------------------------------ +OPENSEARCH_DISCOVERY_TYPE=single-node +OPENSEARCH_BOOTSTRAP_MEMORY_LOCK=true +OPENSEARCH_JAVA_OPTS_MIN=512m +OPENSEARCH_JAVA_OPTS_MAX=1024m +OPENSEARCH_INITIAL_ADMIN_PASSWORD=Qazwsxedc!@#123 +OPENSEARCH_MEMLOCK_SOFT=-1 +OPENSEARCH_MEMLOCK_HARD=-1 +OPENSEARCH_NOFILE_SOFT=65536 +OPENSEARCH_NOFILE_HARD=65536 + +# ------------------------------ +# Environment Variables for Nginx reverse proxy +# ------------------------------ +NGINX_SERVER_NAME=_ +NGINX_HTTPS_ENABLED=false +# HTTP port +NGINX_PORT=80 +# SSL settings are only applied when HTTPS_ENABLED is true +NGINX_SSL_PORT=443 +# if HTTPS_ENABLED is true, you're required to add your own SSL certificates/keys to the `./nginx/ssl` directory +# and modify the env vars below accordingly. +NGINX_SSL_CERT_FILENAME=dify.crt +NGINX_SSL_CERT_KEY_FILENAME=dify.key +NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3 + +# Nginx performance tuning +NGINX_WORKER_PROCESSES=auto +NGINX_CLIENT_MAX_BODY_SIZE=100M +NGINX_KEEPALIVE_TIMEOUT=65 + +# Proxy settings +NGINX_PROXY_READ_TIMEOUT=3600s +NGINX_PROXY_SEND_TIMEOUT=3600s + +# Set true to accept requests for /.well-known/acme-challenge/ +NGINX_ENABLE_CERTBOT_CHALLENGE=false + +# ------------------------------ +# Certbot Configuration +# ------------------------------ + +# Email address (required to get certificates from Let's Encrypt) +CERTBOT_EMAIL=your_email@example.com + +# Domain name +CERTBOT_DOMAIN=your_domain.com + +# certbot command options +# i.e: --force-renewal --dry-run --test-cert --debug +CERTBOT_OPTIONS= + +# ------------------------------ +# Environment Variables for SSRF Proxy +# ------------------------------ +SSRF_HTTP_PORT=3128 +SSRF_COREDUMP_DIR=/var/spool/squid +SSRF_REVERSE_PROXY_PORT=8194 +SSRF_SANDBOX_HOST=sandbox +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 + +# ------------------------------ +# docker env var for specifying vector db type at startup +# (based on the vector db type, the corresponding docker +# compose profile will be used) +# if you want to use unstructured, add ',unstructured' to the end +# ------------------------------ +COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} + +# ------------------------------ +# Docker Compose Service Expose Host Port Configurations +# ------------------------------ +EXPOSE_NGINX_PORT=80 +EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. 
+# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= + +# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP +CSP_WHITELIST= + +# Enable or disable create tidb service job +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 + +# The maximum number of top-k value for RAG. +TOP_K_MAX_VALUE=10 + +# ------------------------------ +# Plugin Daemon Configuration +# ------------------------------ + +DB_PLUGIN_DATABASE=dify_plugin +EXPOSE_PLUGIN_DAEMON_PORT=5002 +PLUGIN_DAEMON_PORT=5002 +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DAEMON_URL=http://plugin_daemon:5002 +PLUGIN_MAX_PACKAGE_SIZE=52428800 +PLUGIN_PPROF_ENABLED=false + +PLUGIN_DEBUGGING_HOST=0.0.0.0 +PLUGIN_DEBUGGING_PORT=5003 +EXPOSE_PLUGIN_DEBUGGING_HOST=localhost +EXPOSE_PLUGIN_DEBUGGING_PORT=5003 + +# If this key is changed, DIFY_INNER_API_KEY in plugin_daemon service must also be updated or agent node will fail. +PLUGIN_DIFY_INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 +PLUGIN_DIFY_INNER_API_URL=http://api:5001 + +ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} + +MARKETPLACE_ENABLED=true +MARKETPLACE_API_URL=https://marketplace.dify.ai + +FORCE_VERIFYING_SIGNATURE=true + +PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 +PLUGIN_MAX_EXECUTION_TIMEOUT=600 +# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple +PIP_MIRROR_URL= + +# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example +# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos +PLUGIN_STORAGE_TYPE=local +PLUGIN_STORAGE_LOCAL_ROOT=/app/storage +PLUGIN_WORKING_PATH=/app/storage/cwd +PLUGIN_INSTALLED_PATH=plugin +PLUGIN_PACKAGE_CACHE_PATH=plugin_packages +PLUGIN_MEDIA_CACHE_PATH=assets +# Plugin oss bucket +PLUGIN_STORAGE_OSS_BUCKET= +# Plugin oss s3 credentials +PLUGIN_S3_USE_AWS=false +PLUGIN_S3_USE_AWS_MANAGED_IAM=false +PLUGIN_S3_ENDPOINT= +PLUGIN_S3_USE_PATH_STYLE=false +PLUGIN_AWS_ACCESS_KEY= +PLUGIN_AWS_SECRET_KEY= +PLUGIN_AWS_REGION= +# Plugin oss azure blob +PLUGIN_AZURE_BLOB_STORAGE_CONTAINER_NAME= +PLUGIN_AZURE_BLOB_STORAGE_CONNECTION_STRING= +# Plugin oss tencent cos +PLUGIN_TENCENT_COS_SECRET_KEY= +PLUGIN_TENCENT_COS_SECRET_ID= +PLUGIN_TENCENT_COS_REGION= +# Plugin oss aliyun oss +PLUGIN_ALIYUN_OSS_REGION= +PLUGIN_ALIYUN_OSS_ENDPOINT= +PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID= +PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= +PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 +PLUGIN_ALIYUN_OSS_PATH= +# Plugin oss volcengine tos +PLUGIN_VOLCENGINE_TOS_ENDPOINT= +PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= +PLUGIN_VOLCENGINE_TOS_SECRET_KEY= +PLUGIN_VOLCENGINE_TOS_REGION= + +# ------------------------------ +# OTLP Collector Configuration +# ------------------------------ +ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= +OTLP_BASE_ENDPOINT=http://localhost:4318 +OTLP_API_KEY= +OTEL_EXPORTER_OTLP_PROTOCOL= +OTEL_EXPORTER_TYPE=otlp +OTEL_SAMPLING_RATE=0.1 +OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000 +OTEL_MAX_QUEUE_SIZE=2048 +OTEL_MAX_EXPORT_BATCH_SIZE=512 +OTEL_METRIC_EXPORT_INTERVAL=60000 +OTEL_BATCH_EXPORT_TIMEOUT=10000 +OTEL_METRIC_EXPORT_TIMEOUT=30000 + +# Prevent Clickjacking +ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. 
e.g. test1@dify.ai,test2@dify.ai
+QUEUE_MONITOR_ALERT_EMAILS=
+# Monitor interval in minutes, default is 30.
+QUEUE_MONITOR_INTERVAL=30
diff --git a/.github/ISSUE_TEMPLATE/chore.yaml b/.github/ISSUE_TEMPLATE/chore.yaml
new file mode 100644
index 0000000000..cf74dcc546
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/chore.yaml
@@ -0,0 +1,44 @@
+name: "✨ Refactor"
+description: Refactor existing code for improved readability and maintainability.
+title: "[Chore/Refactor] "
+labels:
+  - refactor
+body:
+  - type: checkboxes
+    attributes:
+      label: Self Checks
+      description: "To make sure we get to you in time, please check the following :)"
+      options:
+        - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
+          required: true
+        - label: This is only for refactoring. If you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+          required: true
+        - label: I have [searched for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
+          required: true
+        - label: I confirm that I am using English to submit this report; otherwise it will be closed.
+          required: true
+        - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :)
+          required: true
+        - label: "Please do not modify this template :) and fill in all the required fields."
+          required: true
+  - type: textarea
+    id: description
+    attributes:
+      label: Description
+      placeholder: "Describe the refactor you are proposing."
+    validations:
+      required: true
+  - type: textarea
+    id: motivation
+    attributes:
+      label: Motivation
+      placeholder: "Explain why this refactor is necessary."
+    validations:
+      required: false
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional Context
+      placeholder: "Add any other context or screenshots about the request here."
+ validations: + required: false diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index a5a5071fae..9c3daddbfc 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -99,3 +99,6 @@ jobs: - name: Run Tool run: uv run --project api bash dev/pytest/pytest_tools.sh + + - name: Run TestContainers + run: uv run --project api bash dev/pytest/pytest_testcontainers.sh diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 5e290c5d02..152ff3b648 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -9,6 +9,7 @@ permissions: jobs: autofix: + if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index b933560a5e..17af047267 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -7,6 +7,7 @@ on: - "deploy/dev" - "deploy/enterprise" - "build/**" + - "release/e-*" tags: - "*" diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index c79d58563f..1cb9c0967b 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -5,6 +5,10 @@ on: types: [closed] branches: [main] +permissions: + contents: write + pull-requests: write + jobs: check-and-update: if: github.event.pull_request.merged == true @@ -16,7 +20,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 2 # last 2 commits - persist-credentials: false + token: ${{ secrets.GITHUB_TOKEN }} - name: Check for file changes in i18n/en-US id: check_files @@ -49,7 +53,7 @@ jobs: if: env.FILES_CHANGED == 'true' run: pnpm install --frozen-lockfile - - name: Run npm script + - name: Generate i18n translations if: env.FILES_CHANGED == 'true' run: pnpm run auto-gen-i18n @@ -57,6 +61,7 @@ jobs: if: env.FILES_CHANGED == 'true' uses: peter-evans/create-pull-request@v6 with: + token: ${{ secrets.GITHUB_TOKEN }} commit-message: Update i18n files based on en-US changes title: 'chore: translate i18n files' body: This PR was automatically created to update i18n files based on changes in en-US locale. 
diff --git a/.gitignore b/.gitignore index dd4673a3d2..5c68d89a4d 100644 --- a/.gitignore +++ b/.gitignore @@ -215,3 +215,4 @@ mise.toml # AI Assistant .roo/ api/.env.backup +/clickzetta diff --git a/README.md b/README.md index 16a1268cb1..775f6f351f 100644 --- a/README.md +++ b/README.md @@ -235,6 +235,10 @@ Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https:/ One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Deploy to AKS with Azure Devops Pipeline + +One-Click deploy Dify to AKS with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## Contributing diff --git a/README_AR.md b/README_AR.md index d2cb0098a3..e7a4dbdb27 100644 --- a/README_AR.md +++ b/README_AR.md @@ -217,6 +217,10 @@ docker compose up -d انشر ​​Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### استخدام Azure Devops Pipeline للنشر على AKS + +انشر Dify على AKS بنقرة واحدة باستخدام [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## المساهمة diff --git a/README_BN.md b/README_BN.md index f57413ec8b..e4da437eff 100644 --- a/README_BN.md +++ b/README_BN.md @@ -235,6 +235,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) + #### AKS-এ ডিপ্লয় করার জন্য Azure Devops Pipeline ব্যবহার + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ব্যবহার করে Dify কে AKS-এ এক ক্লিকে ডিপ্লয় করুন + ## Contributing diff --git a/README_CN.md b/README_CN.md index e9c73eb48b..82149519d3 100644 --- a/README_CN.md +++ b/README_CN.md @@ -233,6 +233,9 @@ docker compose up -d 使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云 +#### 使用 Azure Devops Pipeline 部署到AKS + +使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 将 Dify 一键部署到 AKS ## Star History diff --git a/README_DE.md b/README_DE.md index d31a56542d..2420ac0392 100644 --- a/README_DE.md +++ b/README_DE.md @@ -230,6 +230,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Verwendung von Azure Devops Pipeline für AKS-Bereitstellung + +Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) verwenden + ## Contributing diff --git a/README_ES.md b/README_ES.md index 918bfe2286..4fa59dc18f 100644 --- a/README_ES.md +++ b/README_ES.md @@ -230,6 +230,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Uso de Azure Devops Pipeline para implementar en AKS + +Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## Contribuir diff --git a/README_FR.md b/README_FR.md index 56ca878aae..dcbc869620 100644 --- a/README_FR.md +++ b/README_FR.md @@ -228,6 +228,10 @@ Déployez Dify sur AWS en utilisant 
[CDK](https://aws.amazon.com/cdk/) Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Utilisation d'Azure Devops Pipeline pour déployer sur AKS + +Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## Contribuer diff --git a/README_JA.md b/README_JA.md index 6d277a36ed..d840fd6419 100644 --- a/README_JA.md +++ b/README_JA.md @@ -227,6 +227,10 @@ docker compose up -d #### Alibaba Cloud Data Management [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます +#### AKSへのデプロイにAzure Devops Pipelineを使用 + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)を使用してDifyをAKSにワンクリックでデプロイ + ## 貢献 diff --git a/README_KL.md b/README_KL.md index dac67eeb29..41c7969e1c 100644 --- a/README_KL.md +++ b/README_KL.md @@ -228,6 +228,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### AKS 'e' Deploy je Azure Devops Pipeline lo'laH + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) lo'laH Dify AKS 'e' wa'DIch click 'e' Deploy + ## Contributing diff --git a/README_KR.md b/README_KR.md index 072481da02..d4b31a8928 100644 --- a/README_KR.md +++ b/README_KR.md @@ -222,6 +222,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다 +#### AKS에 배포하기 위해 Azure Devops Pipeline 사용 + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)을 사용하여 Dify를 AKS에 원클릭으로 배포 + ## 기여 diff --git a/README_PT.md b/README_PT.md index 1260f8e6fd..94452cb233 100644 --- a/README_PT.md +++ b/README_PT.md @@ -227,6 +227,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Usando Azure Devops Pipeline para Implantar no AKS + +Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## Contribuindo diff --git a/README_SI.md b/README_SI.md index 7ded001d86..d840e9155f 100644 --- a/README_SI.md +++ b/README_SI.md @@ -228,6 +228,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Uporaba Azure Devops Pipeline za uvajanje v AKS + +Z enim klikom namestite Dify v AKS z uporabo [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## Prispevam diff --git a/README_TR.md b/README_TR.md index 37953f0de1..470a7570e0 100644 --- a/README_TR.md +++ b/README_TR.md @@ -221,6 +221,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın +#### AKS'ye Dağıtım için Azure Devops Pipeline Kullanımı + +[Azure 
Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) kullanarak Dify'ı tek tıkla AKS'ye dağıtın + ## Katkıda Bulunma diff --git a/README_TW.md b/README_TW.md index f70d6a25f6..18f1d2754a 100644 --- a/README_TW.md +++ b/README_TW.md @@ -233,6 +233,10 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲 +#### 使用 Azure Devops Pipeline 部署到AKS + +使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 將 Dify 一鍵部署到 AKS + ## 貢獻 diff --git a/README_VI.md b/README_VI.md index ddd9aa95f6..2ab6da80fc 100644 --- a/README_VI.md +++ b/README_VI.md @@ -224,6 +224,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Sử dụng Azure Devops Pipeline để Triển khai lên AKS + +Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure Devops Pipeline Helm Chart bởi @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) + ## Đóng góp diff --git a/api/.env.example b/api/.env.example index 18f2dbf647..4beabfecea 100644 --- a/api/.env.example +++ b/api/.env.example @@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com TABLESTORE_INSTANCE_NAME=instance-name TABLESTORE_ACCESS_KEY_ID=xxx TABLESTORE_ACCESS_KEY_SECRET=xxx +TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false # Tidb Vector configuration TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com diff --git a/api/Dockerfile b/api/Dockerfile index e097b5811e..d69291f7ea 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -19,7 +19,7 @@ RUN apt-get update \ # Install Python dependencies COPY pyproject.toml uv.lock ./ -RUN uv sync --locked +RUN uv sync --locked --no-dev # production stage FROM base AS production diff --git a/api/commands.py b/api/commands.py index 79bb6713d0..8ee52ba716 100644 --- a/api/commands.py +++ b/api/commands.py @@ -5,10 +5,11 @@ import secrets from typing import Any, Optional import click +import sqlalchemy as sa from flask import current_app from pydantic import TypeAdapter from sqlalchemy import select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError from configs import dify_config from constants.languages import languages @@ -180,8 +181,8 @@ def migrate_annotation_vector_database(): ) if not apps: break - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for app in apps: @@ -307,8 +308,8 @@ def migrate_knowledge_vector_database(): ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for dataset in datasets: @@ -457,7 +458,7 @@ def convert_to_agent_apps(): """ with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query)) + rs = conn.execute(sa.text(sql_query)) apps = [] for i in rs: @@ -560,8 +561,8 @@ def old_metadata_migration(): .order_by(DatasetDocument.created_at.desc()) ) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise if not documents: break for document in documents: @@ -702,7 +703,7 @@ def fix_app_site_missing(): sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id where sites.id is null limit 1000""" with 
db.engine.begin() as conn: - rs = conn.execute(db.text(sql)) + rs = conn.execute(sa.text(sql)) processed_count = 0 for i in rs: @@ -916,7 +917,7 @@ def clear_orphaned_file_records(force: bool): ) orphaned_message_files = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) @@ -937,7 +938,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: - conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) + conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) click.echo( click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") ) @@ -954,7 +955,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) @@ -974,7 +975,7 @@ def clear_orphaned_file_records(force: bool): f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) elif ids_table["type"] == "text": @@ -989,7 +990,7 @@ def clear_orphaned_file_records(force: bool): f"FROM {ids_table['table']}" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) @@ -1008,7 +1009,7 @@ def clear_orphaned_file_records(force: bool): f"FROM {ids_table['table']}" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) @@ -1037,7 +1038,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" with db.engine.begin() as conn: - conn.execute(db.text(query), {"ids": tuple(orphaned_files)}) + conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) except Exception as e: click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) return @@ -1107,7 +1108,7 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_files_in_tables.append(str(i[0])) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) diff --git 
a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 68b16e48db..4e228ab932 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from .storage.amazon_s3_storage_config import S3StorageConfig from .storage.azure_blob_storage_config import AzureBlobStorageConfig from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig +from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig from .storage.google_cloud_storage_config import GoogleCloudStorageConfig from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig from .storage.oci_storage_config import OCIStorageConfig @@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from .vdb.analyticdb_config import AnalyticdbConfig from .vdb.baidu_vector_config import BaiduVectorDBConfig from .vdb.chroma_config import ChromaConfig +from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig @@ -52,6 +54,7 @@ class StorageConfig(BaseSettings): "aliyun-oss", "azure-blob", "baidu-obs", + "clickzetta-volume", "google-storage", "huawei-obs", "oci-storage", @@ -61,8 +64,9 @@ class StorageConfig(BaseSettings): "local", ] = Field( description="Type of storage to use." - " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', " - "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.", + " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', " + "'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', " + "'volcengine-tos', 'supabase'. Default is 'opendal'.", default="opendal", ) @@ -215,7 +219,7 @@ class DatabaseConfig(BaseSettings): class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( - description="Backend for Celery task results. Options: 'database', 'redis'.", + description="Backend for Celery task results. 
Options: 'database', 'redis', 'rabbitmq'.", default="redis", ) @@ -245,7 +249,12 @@ class CeleryConfig(DatabaseConfig): @computed_field def CELERY_RESULT_BACKEND(self) -> str | None: - return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL + if self.CELERY_BACKEND in ("database", "rabbitmq"): + return f"db+{self.SQLALCHEMY_DATABASE_URI}" + elif self.CELERY_BACKEND == "redis": + return self.CELERY_BROKER_URL + else: + return None @property def BROKER_USE_SSL(self) -> bool: @@ -298,6 +307,7 @@ class MiddlewareConfig( AliyunOSSStorageConfig, AzureBlobStorageConfig, BaiduOBSStorageConfig, + ClickZettaVolumeStorageConfig, GoogleCloudStorageConfig, HuaweiCloudOBSStorageConfig, OCIStorageConfig, @@ -310,6 +320,7 @@ class MiddlewareConfig( VectorStoreConfig, AnalyticdbConfig, ChromaConfig, + ClickzettaConfig, HuaweiCloudConfig, MilvusConfig, MyScaleConfig, diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py new file mode 100644 index 0000000000..56e1b6a957 --- /dev/null +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -0,0 +1,65 @@ +"""ClickZetta Volume Storage Configuration""" + +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class ClickZettaVolumeStorageConfig(BaseSettings): + """Configuration for ClickZetta Volume storage.""" + + CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( + description="Username for ClickZetta Volume authentication", + default=None, + ) + + CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( + description="Password for ClickZetta Volume authentication", + default=None, + ) + + CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( + description="ClickZetta instance identifier", + default=None, + ) + + CLICKZETTA_VOLUME_SERVICE: str = Field( + description="ClickZetta service endpoint", + default="api.clickzetta.com", + ) + + CLICKZETTA_VOLUME_WORKSPACE: str = Field( + description="ClickZetta workspace name", + default="quick_start", + ) + + CLICKZETTA_VOLUME_VCLUSTER: str = Field( + description="ClickZetta virtual cluster name", + default="default_ap", + ) + + CLICKZETTA_VOLUME_SCHEMA: str = Field( + description="ClickZetta schema name", + default="dify", + ) + + CLICKZETTA_VOLUME_TYPE: str = Field( + description="ClickZetta volume type (table|user|external)", + default="user", + ) + + CLICKZETTA_VOLUME_NAME: Optional[str] = Field( + description="ClickZetta volume name for external volumes", + default=None, + ) + + CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field( + description="Prefix for ClickZetta volume table names", + default="dataset_", + ) + + CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field( + description="Directory prefix for User Volume to organize Dify files", + default="dify_km", + ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py new file mode 100644 index 0000000000..04f81e25fc --- /dev/null +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -0,0 +1,69 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class ClickzettaConfig(BaseModel): + """ + Clickzetta Lakehouse vector database configuration + """ + + CLICKZETTA_USERNAME: Optional[str] = Field( + description="Username for authenticating with Clickzetta Lakehouse", + default=None, + ) + + CLICKZETTA_PASSWORD: Optional[str] = Field( + description="Password for authenticating with Clickzetta 
Lakehouse", + default=None, + ) + + CLICKZETTA_INSTANCE: Optional[str] = Field( + description="Clickzetta Lakehouse instance ID", + default=None, + ) + + CLICKZETTA_SERVICE: Optional[str] = Field( + description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')", + default="api.clickzetta.com", + ) + + CLICKZETTA_WORKSPACE: Optional[str] = Field( + description="Clickzetta workspace name", + default="default", + ) + + CLICKZETTA_VCLUSTER: Optional[str] = Field( + description="Clickzetta virtual cluster name", + default="default_ap", + ) + + CLICKZETTA_SCHEMA: Optional[str] = Field( + description="Database schema name in Clickzetta", + default="public", + ) + + CLICKZETTA_BATCH_SIZE: Optional[int] = Field( + description="Batch size for bulk insert operations", + default=100, + ) + + CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field( + description="Enable inverted index for full-text search capabilities", + default=True, + ) + + CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field( + description="Analyzer type for full-text search: keyword, english, chinese, unicode", + default="chinese", + ) + + CLICKZETTA_ANALYZER_MODE: Optional[str] = Field( + description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)", + default="smart", + ) + + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( + description="Distance function for vector similarity: l2_distance or cosine_distance", + default="cosine_distance", + ) diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py index c4dcc0d465..1aab01c6e1 100644 --- a/api/configs/middleware/vdb/tablestore_config.py +++ b/api/configs/middleware/vdb/tablestore_config.py @@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings): description="AccessKey secret for the instance name", default=None, ) + + TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field( + description="Whether to normalize full-text search scores to [0, 1]", + default=False, + ) diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 9e052320ac..c98f4d55c8 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3 IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] +VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"] VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) -AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] +AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index e25f92399c..57dbc8da64 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -84,6 +84,7 @@ from .datasets import ( external, hit_testing, metadata, + upload_file, website, ) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index c2ba880405..ee6011cd65 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -100,7 +100,7 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -class AnnotationListApi(Resource): +class AnnotationApi(Resource): @setup_required @login_required @account_initialization_required @@ -123,6 +123,23 @@ 
class AnnotationListApi(Resource): } return response, 200 + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("annotation") + @marshal_with(annotation_fields) + def post(self, app_id): + if not current_user.is_editor: + raise Forbidden() + + app_id = str(app_id) + parser = reqparse.RequestParser() + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") + args = parser.parse_args() + annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) + return annotation + @setup_required @login_required @account_initialization_required @@ -131,8 +148,25 @@ class AnnotationListApi(Resource): raise Forbidden() app_id = str(app_id) - AppAnnotationService.clear_all_annotations(app_id) - return {"result": "success"}, 204 + + # Use request.args.getlist to get annotation_ids array directly + annotation_ids = request.args.getlist("annotation_id") + + # If annotation_ids are provided, handle batch deletion + if annotation_ids: + # Check if any annotation_ids contain empty strings or invalid values + if not all(annotation_id.strip() for annotation_id in annotation_ids if annotation_id): + return { + "code": "bad_request", + "message": "annotation_ids are required if the parameter is provided.", + }, 400 + + result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids) + return result, 204 + # If no annotation_ids are provided, handle clearing all annotations + else: + AppAnnotationService.clear_all_annotations(app_id) + return {"result": "success"}, 204 class AnnotationExportApi(Resource): @@ -149,25 +183,6 @@ class AnnotationExportApi(Resource): return response, 200 -class AnnotationCreateApi(Resource): - @setup_required - @login_required - @account_initialization_required - @cloud_edition_billing_resource_check("annotation") - @marshal_with(annotation_fields) - def post(self, app_id): - if not current_user.is_editor: - raise Forbidden() - - app_id = str(app_id) - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() - annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) - return annotation - - class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @@ -210,14 +225,15 @@ class AnnotationBatchImportApi(Resource): raise Forbidden() app_id = str(app_id) - # get file from request - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + + # get file from request + file = request.files["file"] # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. 
Only CSV files are allowed") @@ -276,7 +292,7 @@ api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply api.add_resource( AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" ) -api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationApi, "/apps//annotations") api.add_resource(AnnotationExportApi, "/apps//annotations/export") api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9fe32dde6d..1cc13d669c 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -28,6 +28,12 @@ from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +def _validate_description_length(description): + if description and len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + class AppListApi(Resource): @setup_required @login_required @@ -94,7 +100,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") @@ -146,7 +152,7 @@ class AppApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -189,7 +195,7 @@ class AppCopyApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 32b64d10c5..343b7acd7b 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -67,7 +67,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "message_count": i.message_count}) @@ -176,7 +176,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) @@ -234,7 +234,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = 
conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} @@ -310,7 +310,7 @@ ORDER BY response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} @@ -373,7 +373,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( { @@ -435,7 +435,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) @@ -495,7 +495,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707b..7f80afd83b 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -2,6 +2,7 @@ from datetime import datetime from decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user from flask_restful import Resource, reqparse @@ -71,7 +72,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "runs": i.runs}) @@ -133,7 +134,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) @@ -195,7 +196,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( { @@ -277,7 +278,7 @@ GROUP BY response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f551bc2432..2befd2a651 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -41,7 +41,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -113,7 +113,7 @@ class DatasetListApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", @@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.HUAWEI_CLOUD | VectorType.TENCENT | 
VectorType.MATRIXONE + | VectorType.CLICKZETTA ): return { "retrieval_method": [ @@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.TENCENT | VectorType.HUAWEI_CLOUD | VectorType.MATRIXONE + | VectorType.CLICKZETTA ): return { "retrieval_method": [ diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b6e91dd98e..4e0955bd43 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -642,7 +642,7 @@ class DocumentIndexingStatusApi(DocumentResource): return marshal(document_dict, document_status_fields) -class DocumentDetailApi(DocumentResource): +class DocumentApi(DocumentResource): METADATA_CHOICES = {"all", "only", "without"} @setup_required @@ -730,6 +730,28 @@ class DocumentDetailApi(DocumentResource): return response, 200 + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def delete(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + + document = self.get_document(dataset_id, document_id) + + try: + DocumentService.delete_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return {"result": "success"}, 204 + class DocumentProcessingApi(DocumentResource): @setup_required @@ -768,30 +790,6 @@ class DocumentProcessingApi(DocumentResource): return {"result": "success"}, 200 -class DocumentDeleteApi(DocumentResource): - @setup_required - @login_required - @account_initialization_required - @cloud_edition_billing_rate_limit_check("knowledge") - def delete(self, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) - if dataset is None: - raise NotFound("Dataset not found.") - # check user's model setting - DatasetService.check_dataset_model_setting(dataset) - - document = self.get_document(dataset_id, document_id) - - try: - DocumentService.delete_document(document) - except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError("Cannot delete document during indexing.") - - return {"result": "success"}, 204 - - class DocumentMetadataApi(DocumentResource): @setup_required @login_required @@ -1037,11 +1035,10 @@ api.add_resource( api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") -api.add_resource(DocumentDetailApi, "/datasets//documents/") +api.add_resource(DocumentApi, "/datasets//documents/") api.add_resource( DocumentProcessingApi, "/datasets//documents//processing/" ) -api.add_resource(DocumentDeleteApi, "/datasets//documents/") api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") diff --git a/api/controllers/console/datasets/upload_file.py b/api/controllers/console/datasets/upload_file.py new file mode 100644 index 
0000000000..9b456c771d --- /dev/null +++ b/api/controllers/console/datasets/upload_file.py @@ -0,0 +1,62 @@ +from flask_login import current_user +from flask_restful import Resource +from werkzeug.exceptions import NotFound + +from controllers.console import api +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from core.file import helpers as file_helpers +from extensions.ext_database import db +from models.dataset import Dataset +from models.model import UploadFile +from services.dataset_service import DocumentService + + +class UploadFileApi(Resource): + @setup_required + @account_initialization_required + def get(self, dataset_id, document_id): + """Get upload file.""" + # check dataset + dataset_id = str(dataset_id) + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id) + .first() + ) + if not dataset: + raise NotFound("Dataset not found.") + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset.id, document_id) + if not document: + raise NotFound("Document not found.") + # check upload file + if document.data_source_type != "upload_file": + raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") + data_source_info = document.data_source_info_dict + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + if not upload_file: + raise NotFound("UploadFile not found.") + else: + raise ValueError("Upload file id not found in document data source info.") + + url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": url, + "download_url": f"{url}&as_attachment=true", + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at.timestamp(), + }, 200 + + +api.add_resource(UploadFileApi, "/datasets//documents//upload-file") diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 6944c56bf8..0a4dfe1c10 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException): code = 429 -class CompilanceRateLimitError(BaseHTTPException): - error_code = "compilance_rate_limit" +class ComplianceRateLimitError(BaseHTTPException): + error_code = "compliance_rate_limit" description = "Rate limit exceeded for downloading compliance report." 
code = 429 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 6d9f794307..ad62bd6e08 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource): # filter out apps that user doesn't have access to if FeatureService.get_system_features().webapp_auth.enabled: user_id = current_user.id - res = [] app_ids = [installed_app["app"].id for installed_app in installed_app_list] webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) + + # Pre-filter out apps without setting or with sso_verified + filtered_installed_apps = [] + app_id_to_app_code = {} + for installed_app in installed_app_list: - webapp_setting = webapp_settings.get(installed_app["app"].id) - if not webapp_setting: + app_id = installed_app["app"].id + webapp_setting = webapp_settings.get(app_id) + if not webapp_setting or webapp_setting.access_mode == "sso_verified": continue - if webapp_setting.access_mode == "sso_verified": - continue - app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) - if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( - user_id=user_id, - app_code=app_code, - ): + app_code = AppService.get_app_code_by_id(str(app_id)) + app_id_to_app_code[app_id] = app_code + filtered_installed_apps.append(installed_app) + + app_codes = list(app_id_to_app_code.values()) + + # Batch permission check + permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( + user_id=user_id, + app_codes=app_codes, + ) + + # Keep only allowed apps + res = [] + for installed_app in filtered_installed_apps: + app_id = installed_app["app"].id + app_code = app_id_to_app_code[app_id] + if permissions.get(app_code): res.append(installed_app) + installed_app_list = res logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 66b6214f82..256ff24b3b 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -49,7 +49,6 @@ class FileApi(Resource): @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") def post(self): - file = request.files["file"] source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None @@ -58,6 +57,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 19999e7361..6012c9ecc8 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): - # get file from request - file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() @@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + # get file from request + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d964e27819..b26f29d98d 
100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) from . import index -from .app import annotation, app, audio, completion, conversation, file, message, site, workflow +from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index edc66cc5e9..ea57f04850 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -2,7 +2,7 @@ import logging from flask import request from flask_restful import Resource, reqparse -from werkzeug.exceptions import InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services from controllers.service_api import api @@ -30,6 +30,7 @@ from libs import helper from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService +from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -113,7 +114,7 @@ class ChatApi(Resource): parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") - + parser.add_argument("workflow_id", type=str, required=False, location="json") args = parser.parse_args() external_trace_id = get_external_trace_id(request) @@ -128,6 +129,12 @@ class ChatApi(Resource): ) return helper.compact_generate_response(response) + except WorkflowNotFoundError as ex: + raise NotFound(str(ex)) + except IsDraftWorkflowError as ex: + raise BadRequest(str(ex)) + except WorkflowIdFormatError as ex: + raise BadRequest(str(ex)) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 36a7905572..79c860e6b8 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,7 +1,9 @@ +import json + from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range from sqlalchemy.orm import Session -from werkzeug.exceptions import NotFound +from werkzeug.exceptions import BadRequest, NotFound import services from controllers.service_api import api @@ -15,6 +17,7 @@ from fields.conversation_fields import ( simple_conversation_fields, ) from fields.conversation_variable_fields import ( + conversation_variable_fields, conversation_variable_infinite_scroll_pagination_fields, ) from libs.helper import uuid_value @@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource): raise NotFound("Conversation Not Exists.") +class ConversationVariableDetailApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @marshal_with(conversation_variable_fields) + def put(self, app_model: App, end_user: EndUser, c_id, variable_id): + 
"""Update a conversation variable's value""" + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: + raise NotChatAppError() + + conversation_id = str(c_id) + variable_id = str(variable_id) + + parser = reqparse.RequestParser() + parser.add_argument("value", required=True, location="json") + args = parser.parse_args() + + try: + return ConversationService.update_conversation_variable( + app_model, conversation_id, variable_id, end_user, json.loads(args["value"]) + ) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationVariableNotExistsError: + raise NotFound("Conversation Variable Not Exists.") + except services.errors.conversation.ConversationVariableTypeMismatchError as e: + raise BadRequest(str(e)) + + api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") api.add_resource(ConversationApi, "/conversations") api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") api.add_resource(ConversationVariablesApi, "/conversations//variables", endpoint="conversation_variables") +api.add_resource( + ConversationVariableDetailApi, + "/conversations//variables/", + endpoint="conversation_variable_detail", + methods=["PUT"], +) diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ca91da80c1..ba705f71e2 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException): error_code = "unsupported_file_type" description = "File type not allowed." code = 415 + + +class FileNotFoundError(BaseHTTPException): + error_code = "file_not_found" + description = "The requested file was not found." + code = 404 + + +class FileAccessDeniedError(BaseHTTPException): + error_code = "file_access_denied" + description = "Access to the requested file is denied." 
+ code = 403 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef..f09d07bcb6 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -20,18 +20,17 @@ class FileApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() - if not file.mimetype: - raise UnsupportedFileTypeError() - if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] + if not file.mimetype: + raise UnsupportedFileTypeError() + if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py new file mode 100644 index 0000000000..57141033d1 --- /dev/null +++ b/api/controllers/service_api/app/file_preview.py @@ -0,0 +1,186 @@ +import logging +from urllib.parse import quote + +from flask import Response +from flask_restful import Resource, reqparse + +from controllers.service_api import api +from controllers.service_api.app.error import ( + FileAccessDeniedError, + FileNotFoundError, +) +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import App, EndUser, Message, MessageFile, UploadFile + +logger = logging.getLogger(__name__) + + +class FilePreviewApi(Resource): + """ + Service API File Preview endpoint + + Provides secure file preview/download functionality for external API users. + Files can only be accessed if they belong to messages within the requesting app's context. + """ + + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + def get(self, app_model: App, end_user: EndUser, file_id: str): + """ + Preview/Download a file that was uploaded via Service API + + Args: + app_model: The authenticated app model + end_user: The authenticated end user (optional) + file_id: UUID of the file to preview + + Query Parameters: + user: Optional user identifier + as_attachment: Boolean, whether to download as attachment (default: false) + + Returns: + Stream response with file content + + Raises: + FileNotFoundError: File does not exist + FileAccessDeniedError: File access denied (not owned by app) + """ + file_id = str(file_id) + + # Parse query parameters + parser = reqparse.RequestParser() + parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") + args = parser.parse_args() + + # Validate file ownership and get file objects + message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) + + # Get file content generator + try: + generator = storage.load(upload_file.key, stream=True) + except Exception as e: + raise FileNotFoundError(f"Failed to load file content: {str(e)}") + + # Build response with appropriate headers + response = self._build_file_response(generator, upload_file, args["as_attachment"]) + + return response + + def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]: + """ + Validate that the file belongs to a message within the requesting app's context + + Security validations performed: + 1. File exists in MessageFile table (was used in a conversation) + 2. 
Message belongs to the requesting app + 3. UploadFile record exists and is accessible + 4. File tenant matches app tenant (additional security layer) + + Args: + file_id: UUID of the file to validate + app_id: UUID of the requesting app + + Returns: + Tuple of (MessageFile, UploadFile) if validation passes + + Raises: + FileNotFoundError: File or related records not found + FileAccessDeniedError: File does not belong to the app's context + """ + try: + # Input validation + if not file_id or not app_id: + raise FileAccessDeniedError("Invalid file or app identifier") + + # First, find the MessageFile that references this upload file + message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + + if not message_file: + raise FileNotFoundError("File not found in message context") + + # Get the message and verify it belongs to the requesting app + message = ( + db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + ) + + if not message: + raise FileAccessDeniedError("File access denied: not owned by requesting app") + + # Get the actual upload file record + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + raise FileNotFoundError("Upload file record not found") + + # Additional security: verify tenant isolation + app = db.session.query(App).where(App.id == app_id).first() + if app and upload_file.tenant_id != app.tenant_id: + raise FileAccessDeniedError("File access denied: tenant mismatch") + + return message_file, upload_file + + except (FileNotFoundError, FileAccessDeniedError): + # Re-raise our custom exceptions + raise + except Exception as e: + # Log unexpected errors for debugging + logger.exception( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": str(e)}, + ) + raise FileAccessDeniedError("File access validation failed") + + def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response: + """ + Build Flask Response object with appropriate headers for file streaming + + Args: + generator: File content generator from storage + upload_file: UploadFile database record + as_attachment: Whether to set Content-Disposition as attachment + + Returns: + Flask Response object with streaming file content + """ + response = Response( + generator, + mimetype=upload_file.mime_type, + direct_passthrough=True, + headers={}, + ) + + # Add Content-Length if known + if upload_file.size and upload_file.size > 0: + response.headers["Content-Length"] = str(upload_file.size) + + # Add Accept-Ranges header for audio/video files to support seeking + if upload_file.mime_type in [ + "audio/mpeg", + "audio/wav", + "audio/mp4", + "audio/ogg", + "audio/flac", + "audio/aac", + "video/mp4", + "video/webm", + "video/quicktime", + "audio/x-m4a", + ]: + response.headers["Accept-Ranges"] = "bytes" + + # Set Content-Disposition for downloads + if as_attachment and upload_file.name: + encoded_filename = quote(upload_file.name) + response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" + # Override content-type for downloads to force download + response.headers["Content-Type"] = "application/octet-stream" + + # Add caching headers for performance + response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour + + return response + + +# Register the API endpoint +api.add_resource(FilePreviewApi, "/files//preview") diff 
--git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 370ff911b4..cd8a5f03ac 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -5,7 +5,7 @@ from flask import request from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range from sqlalchemy.orm import Session, sessionmaker -from werkzeug.exceptions import InternalServerError +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.service_api import api from controllers.service_api.app.error import ( @@ -34,6 +34,7 @@ from libs.helper import TimestampField from models.model import App, AppMode, EndUser from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService +from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService @@ -120,6 +121,59 @@ class WorkflowRunApi(Resource): raise InternalServerError() +class WorkflowRunByIdApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, workflow_id: str): + """ + Run specific workflow by ID + """ + app_mode = AppMode.value_of(app_model.mode) + if app_mode != AppMode.WORKFLOW: + raise NotWorkflowAppError() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + args = parser.parse_args() + + # Add workflow_id to args for AppGenerateService + args["workflow_id"] = workflow_id + + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + streaming = args.get("response_mode") == "streaming" + + try: + response = AppGenerateService.generate( + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming + ) + + return helper.compact_generate_response(response) + except WorkflowNotFoundError as ex: + raise NotFound(str(ex)) + except IsDraftWorkflowError as ex: + raise BadRequest(str(ex)) + except WorkflowIdFormatError as ex: + raise BadRequest(str(ex)) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + class WorkflowTaskStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id: str): @@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource): api.add_resource(WorkflowRunApi, "/workflows/run") api.add_resource(WorkflowRunDetailApi, "/workflows/run/") +api.add_resource(WorkflowRunByIdApi, "/workflows//run") api.add_resource(WorkflowTaskStopApi, 
"/workflows/tasks//stop") api.add_resource(WorkflowAppLogApi, "/workflows/logs") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index a499719fc3..29eef41253 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -29,7 +29,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac85c0b38d..2955d5d20d 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource): args["retrieval_model"].get("reranking_model").get("reranking_model_name"), ) - # save file info - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() @@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() + # save file info + file = request.files["file"] if not file.filename: raise FilenameNotExistsError @@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 -class DocumentDeleteApi(DatasetApiResource): - @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id): - """Delete document.""" - document_id = str(document_id) - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - - # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - - if not dataset: - raise ValueError("Dataset does not exist.") - - document = DocumentService.get_document(dataset.id, document_id) - - # 404 if document not found - if document is None: - raise NotFound("Document Not Exists.") - - # 403 if document is archived - if DocumentService.check_archived(document): - raise ArchivedDocumentImmutableError() - - try: - # delete document - DocumentService.delete_document(document) - except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError("Cannot delete document during indexing.") - - return 204 - - class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) @@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data -class DocumentDetailApi(DatasetApiResource): +class DocumentApi(DatasetApiResource): METADATA_CHOICES = {"all", "only", "without"} def get(self, tenant_id, dataset_id, document_id): @@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource): return response + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def delete(self, tenant_id, dataset_id, document_id): + """Delete document.""" + document_id = str(document_id) + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + + # get dataset info + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + + if not dataset: + raise ValueError("Dataset does not exist.") + + document = 
DocumentService.get_document(dataset.id, document_id) + + # 404 if document not found + if document is None: + raise NotFound("Document Not Exists.") + + # 403 if document is archived + if DocumentService.check_archived(document): + raise ArchivedDocumentImmutableError() + + try: + # delete document + DocumentService.delete_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return 204 + api.add_resource( DocumentAddByTextApi, @@ -588,7 +586,6 @@ api.add_resource( "/datasets//documents//update_by_file", "/datasets//documents//update-by-file", ) -api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentApi, "/datasets//documents/") api.add_resource(DocumentListApi, "/datasets//documents") api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") -api.add_resource(DocumentDetailApi, "/datasets//documents/") diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index df06a73a85..8e9317606e 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -12,18 +12,17 @@ from services.file_service import FileService class FileApi(WebApiResource): @marshal_with(file_fields) def post(self, app_model, end_user): - file = request.files["file"] - source = request.form.get("source") - if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] if not file.filename: raise FilenameNotExistsError + source = request.form.get("source") if source not in ("datasets", None): source = None diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 75bd2f677a..0df0aa59b2 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -148,6 +148,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index f0e9425e3f..f3b9dbf758 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -23,6 +23,7 @@ from core.app.entities.task_entities import ( MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + StreamEvent, WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator @@ -180,11 +181,15 @@ class MessageCycleManager: :param message_id: message id :return: """ + message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first() + event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE + return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, + event=event_type, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 2a0751a5ee..a5a6e62bd7 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig): scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None required: bool = False - default: Optional[Union[int, str]] = None + default: Optional[Union[int, str, float, bool]] 
= None options: Optional[list[Option]] = None label: Optional[I18nObject] = None help: Optional[I18nObject] = None diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index f8c050c2ac..770014aa72 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute): case FileAttribute.TRANSFER_METHOD: return file.transfer_method.value case FileAttribute.URL: - return file.remote_url + return _to_url(file) case FileAttribute.EXTENSION: return file.extension case FileAttribute.RELATED_ID: diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7ce124594a..91f17568b6 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -121,9 +121,8 @@ class TokenBufferMemory: curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: - pruned_memory = [] while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: - pruned_memory.append(prompt_messages.pop(0)) + prompt_messages.pop(0) curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py rename to api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index a20f2485c8..e7c90c1229 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -4,6 +4,7 @@ import logging import os from datetime import datetime, timedelta from typing import Any, Optional, Union, cast +from urllib.parse import urlparse from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra try: # Choose the appropriate exporter based on config type exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter] + + # Inspect the provided endpoint to determine its structure + parsed = urlparse(arize_phoenix_config.endpoint) + base_endpoint = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + if isinstance(arize_phoenix_config, ArizeConfig): - arize_endpoint = f"{arize_phoenix_config.endpoint}/v1" + arize_endpoint = f"{base_endpoint}/v1" arize_headers = { "api_key": arize_phoenix_config.api_key or "", "space_id": arize_phoenix_config.space_id or "", @@ -53,7 +60,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra timeout=30, ) else: - phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces" + phoenix_endpoint = f"{base_endpoint}{path}/v1/traces" phoenix_headers = { "api_key": arize_phoenix_config.api_key or "", "authorization": f"Bearer {arize_phoenix_config.api_key or ''}", diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 626782cee5..851a77fbc1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod def 
endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com") + return validate_url_with_path(v, "https://app.phoenix.arize.com") class LangfuseConfig(BaseTracingConfig): diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index a607c76beb..7eb5da7e3a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -322,7 +322,7 @@ class OpsTraceManager: :return: """ # auth check - if enabled == True: + if enabled: try: provider_config_map[tracing_provider] except KeyError: @@ -407,7 +407,6 @@ class TraceTask: def __init__( self, trace_type: Any, - trace_id: Optional[str] = None, message_id: Optional[str] = None, workflow_execution: Optional[WorkflowExecution] = None, conversation_id: Optional[str] = None, @@ -423,7 +422,7 @@ class TraceTask: self.timer = timer self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.app_id = None - + self.trace_id = None self.kwargs = kwargs external_trace_id = kwargs.get("external_trace_id") if external_trace_id: diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7375726fa9..6f32498b42 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -208,6 +208,7 @@ class BasePluginClient: except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) + logger.error("Error in stream reponse for plugin %s", rep.__dict__) self._handle_plugin_daemon_error(error.error_type, error.message) raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}") if rep.data is None: diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 8b660c807d..8ecc2e2147 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -2,6 +2,8 @@ from collections.abc import Mapping from pydantic import TypeAdapter +from extensions.ext_logging import get_request_id + class PluginDaemonError(Exception): """Base class for all plugin daemon errors.""" @@ -11,7 +13,7 @@ class PluginDaemonError(Exception): def __str__(self) -> str: # returns the class name and description - return f"{self.__class__.__name__}: {self.description}" + return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}" class PluginDaemonInternalError(PluginDaemonError): diff --git a/api/core/rag/datasource/vdb/clickzetta/README.md b/api/core/rag/datasource/vdb/clickzetta/README.md new file mode 100644 index 0000000000..40229f8d44 --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/README.md @@ -0,0 +1,190 @@ +# Clickzetta Vector Database Integration + +This module provides integration with Clickzetta Lakehouse as a vector database for Dify. 
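As a quick orientation before the sections below, here is a minimal sketch (editorial illustration, not part of the patch) of how the `ClickzettaConfig` settings model added in `api/configs/middleware/vdb/clickzetta_config.py` might be populated from the environment variables documented in the Configuration section. The exact wiring inside Dify's vector-store factory may differ; only the field names and defaults shown in this diff are assumed.

```python
# Illustrative sketch only: builds the ClickzettaConfig pydantic model added in
# api/configs/middleware/vdb/clickzetta_config.py from environment variables.
# Dify's own middleware/config loading may wire this differently.
import os

from configs.middleware.vdb.clickzetta_config import ClickzettaConfig

config = ClickzettaConfig(
    # Required credentials (no defaults in the model, Optional[str] = None)
    CLICKZETTA_USERNAME=os.getenv("CLICKZETTA_USERNAME"),
    CLICKZETTA_PASSWORD=os.getenv("CLICKZETTA_PASSWORD"),
    CLICKZETTA_INSTANCE=os.getenv("CLICKZETTA_INSTANCE"),
    # Connection parameters, falling back to the defaults declared in the model
    CLICKZETTA_SERVICE=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
    CLICKZETTA_WORKSPACE=os.getenv("CLICKZETTA_WORKSPACE", "default"),
    CLICKZETTA_VCLUSTER=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
    CLICKZETTA_SCHEMA=os.getenv("CLICKZETTA_SCHEMA", "public"),
    # Search behaviour (defaults shown in the model: chinese analyzer, smart mode,
    # cosine_distance) is left untouched here.
)

print(config.model_dump())  # inspect the resolved configuration
```

With `VECTOR_STORE=clickzetta` set, these are the same seven required values and optional tuning knobs described in the Configuration and Usage sections that follow.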
+ +## Features + +- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type +- **Vector Search**: Efficient similarity search using HNSW algorithm +- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities +- **Hybrid Search**: Combine vector similarity and full-text search for better results +- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing +- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments + +## Configuration + +### Required Environment Variables + +All seven configuration parameters are required: + +```bash +# Authentication +CLICKZETTA_USERNAME=your_username +CLICKZETTA_PASSWORD=your_password + +# Instance configuration +CLICKZETTA_INSTANCE=your_instance_id +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=your_workspace +CLICKZETTA_VCLUSTER=your_vcluster +CLICKZETTA_SCHEMA=your_schema +``` + +### Optional Configuration + +```bash +# Batch processing +CLICKZETTA_BATCH_SIZE=100 + +# Full-text search configuration +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode +CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart + +# Vector search configuration +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance +``` + +## Usage + +### 1. Set Clickzetta as the Vector Store + +In your Dify configuration, set: + +```bash +VECTOR_STORE=clickzetta +``` + +### 2. Table Structure + +Clickzetta will automatically create tables with the following structure: + +```sql +CREATE TABLE <table_name> ( + id STRING NOT NULL, + content STRING NOT NULL, + metadata JSON, + vector VECTOR(FLOAT, <dimension>) NOT NULL, + PRIMARY KEY (id) +); + +-- Vector index for similarity search +CREATE VECTOR INDEX idx_<table_name>_vec +ON TABLE <schema>.<table_name>(vector) +PROPERTIES ( + "distance.function" = "cosine_distance", + "scalar.type" = "f32" +); + +-- Inverted index for full-text search (if enabled) +CREATE INVERTED INDEX idx_<table_name>_text +ON <schema>.<table_name>(content) +PROPERTIES ( + "analyzer" = "chinese", + "mode" = "smart" +); +``` + +## Full-Text Search Capabilities + +Clickzetta supports advanced full-text search with multiple analyzers: + +### Analyzer Types + +1. **keyword**: No tokenization, treats the entire string as a single token + - Best for: Exact matching, IDs, codes + +2. **english**: Designed for English text + - Features: Recognizes ASCII letters and numbers, converts to lowercase + - Best for: English content + +3. **chinese**: Chinese text tokenizer + - Features: Recognizes Chinese and English characters, removes punctuation + - Best for: Chinese or mixed Chinese-English content + +4. **unicode**: Multi-language tokenizer based on Unicode + - Features: Recognizes text boundaries in multiple languages + - Best for: Multi-language content + +### Analyzer Modes + +- **max_word**: Fine-grained tokenization (more tokens) +- **smart**: Intelligent tokenization (balanced) + +### Full-Text Search Functions + +- `MATCH_ALL(column, query)`: All terms must be present +- `MATCH_ANY(column, query)`: At least one term must be present +- `MATCH_PHRASE(column, query)`: Exact phrase matching +- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching +- `MATCH_REGEXP(column, pattern)`: Regular expression matching + +## Performance Optimization + +### Vector Search + +1. **Adjust exploration factor** for accuracy vs speed trade-off: + ```sql + SET cz.vector.index.search.ef=64; + ``` + +2. 
**Use appropriate distance functions**: + - `cosine_distance`: Best for normalized embeddings (e.g., from language models) + - `l2_distance`: Best for raw feature vectors + +### Full-Text Search + +1. **Choose the right analyzer**: + - Use `keyword` for exact matching + - Use language-specific analyzers for better tokenization + +2. **Combine with vector search**: + - Pre-filter with full-text search for better performance + - Use hybrid search for improved relevance + +## Troubleshooting + +### Connection Issues + +1. Verify all 7 required configuration parameters are set +2. Check network connectivity to Clickzetta service +3. Ensure the user has proper permissions on the schema + +### Search Performance + +1. Verify vector index exists: + ```sql + SHOW INDEX FROM <schema>.<table_name>; + ``` + +2. Check if vector index is being used: + ```sql + EXPLAIN SELECT ... WHERE l2_distance(...) < threshold; + ``` + Look for `vector_index_search_type` in the execution plan. + +### Full-Text Search Not Working + +1. Verify inverted index is created +2. Check analyzer configuration matches your content language +3. Use `TOKENIZE()` function to test tokenization: + ```sql + SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart')); + ``` + +## Limitations + +1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns +2. Full-text search relevance scores are not provided by Clickzetta +3. Inverted index creation may fail for very large existing tables (the integration logs a warning and continues without error) +4. Index naming constraints: + - Index names must be unique within a schema + - Only one vector index can be created per column + - The implementation derives index names from the table name (e.g., `idx_<table_name>_vector`) to keep them unique per table +5. A column can only have one vector index at a time + +## References + +- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md) +- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md) +- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/) diff --git a/api/core/rag/datasource/vdb/clickzetta/__init__.py b/api/core/rag/datasource/vdb/clickzetta/__init__.py new file mode 100644 index 0000000000..9d41c5a57d --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/__init__.py @@ -0,0 +1 @@ +# Clickzetta Vector Database Integration for Dify diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py new file mode 100644 index 0000000000..50a395a373 --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -0,0 +1,843 @@ +import json +import logging +import queue +import threading +import uuid +from typing import TYPE_CHECKING, Any, Optional + +import clickzetta # type: ignore +from pydantic import BaseModel, model_validator + +if TYPE_CHECKING: + from clickzetta import Connection + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +# ClickZetta Lakehouse Vector Database Configuration + + +class ClickzettaConfig(BaseModel): + """ + Configuration class for Clickzetta connection.
+ """ + + username: str + password: str + instance: str + service: str = "api.clickzetta.com" + workspace: str = "quick_start" + vcluster: str = "default_ap" + schema_name: str = "dify" # Renamed to avoid shadowing BaseModel.schema + # Advanced settings + batch_size: int = 20 # Reduced batch size to avoid large SQL statements + enable_inverted_index: bool = True # Enable inverted index for full-text search + analyzer_type: str = "chinese" # Analyzer type for full-text search: keyword, english, chinese, unicode + analyzer_mode: str = "smart" # Analyzer mode: max_word, smart + vector_distance_function: str = "cosine_distance" # l2_distance or cosine_distance + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """ + Validate the configuration values. + """ + if not values.get("username"): + raise ValueError("config CLICKZETTA_USERNAME is required") + if not values.get("password"): + raise ValueError("config CLICKZETTA_PASSWORD is required") + if not values.get("instance"): + raise ValueError("config CLICKZETTA_INSTANCE is required") + if not values.get("service"): + raise ValueError("config CLICKZETTA_SERVICE is required") + if not values.get("workspace"): + raise ValueError("config CLICKZETTA_WORKSPACE is required") + if not values.get("vcluster"): + raise ValueError("config CLICKZETTA_VCLUSTER is required") + if not values.get("schema_name"): + raise ValueError("config CLICKZETTA_SCHEMA is required") + return values + + +class ClickzettaVector(BaseVector): + """ + Clickzetta vector storage implementation. + """ + + # Class-level write queue and lock for serializing writes + _write_queue: Optional[queue.Queue] = None + _write_thread: Optional[threading.Thread] = None + _write_lock = threading.Lock() + _shutdown = False + + def __init__(self, collection_name: str, config: ClickzettaConfig): + super().__init__(collection_name) + self._config = config + self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name + self._connection: Optional[Connection] = None + self._init_connection() + self._init_write_queue() + + def _init_connection(self): + """Initialize Clickzetta connection.""" + self._connection = clickzetta.connect( + username=self._config.username, + password=self._config.password, + instance=self._config.instance, + service=self._config.service, + workspace=self._config.workspace, + vcluster=self._config.vcluster, + schema=self._config.schema_name, + ) + + # Set session parameters for better string handling and performance optimization + if self._connection is not None: + with self._connection.cursor() as cursor: + # Use quote mode for string literal escaping to handle quotes better + cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") + logger.info("Set string literal escape mode to 'quote' for better quote handling") + + # Performance optimization hints for vector operations + self._set_performance_hints(cursor) + + def _set_performance_hints(self, cursor): + """Set ClickZetta performance optimization hints for vector operations.""" + try: + # Performance optimization hints for vector operations and query processing + performance_hints = [ + # Vector index optimization + "SET cz.storage.parquet.vector.index.read.memory.cache = true", + "SET cz.storage.parquet.vector.index.read.local.cache = false", + # Query optimization + "SET cz.sql.table.scan.push.down.filter = true", + "SET cz.sql.table.scan.enable.ensure.filter = true", + "SET cz.storage.always.prefetch.internal = true", + "SET 
cz.optimizer.generate.columns.always.valid = true", + "SET cz.sql.index.prewhere.enabled = true", + # Storage optimization + "SET cz.storage.parquet.enable.io.prefetch = false", + "SET cz.optimizer.enable.mv.rewrite = false", + "SET cz.sql.dump.as.lz4 = true", + "SET cz.optimizer.limited.optimization.naive.query = true", + "SET cz.sql.table.scan.enable.push.down.log = false", + "SET cz.storage.use.file.format.local.stats = false", + "SET cz.storage.local.file.object.cache.level = all", + # Job execution optimization + "SET cz.sql.job.fast.mode = true", + "SET cz.storage.parquet.non.contiguous.read = true", + "SET cz.sql.compaction.after.commit = true", + ] + + for hint in performance_hints: + cursor.execute(hint) + + logger.info( + "Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints) + ) + + except Exception: + # Catch any errors setting performance hints but continue with defaults + logger.exception("Failed to set some performance hints, continuing with default settings") + + @classmethod + def _init_write_queue(cls): + """Initialize the write queue and worker thread.""" + with cls._write_lock: + if cls._write_queue is None: + cls._write_queue = queue.Queue() + cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True) + cls._write_thread.start() + logger.info("Started Clickzetta write worker thread") + + @classmethod + def _write_worker(cls): + """Worker thread that processes write tasks sequentially.""" + while not cls._shutdown: + try: + # Get task from queue with timeout + if cls._write_queue is not None: + task = cls._write_queue.get(timeout=1) + if task is None: # Shutdown signal + break + + # Execute the write task + func, args, kwargs, result_queue = task + try: + result = func(*args, **kwargs) + result_queue.put((True, result)) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Write task failed") + result_queue.put((False, e)) + finally: + cls._write_queue.task_done() + else: + break + except queue.Empty: + continue + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Write worker error") + + def _execute_write(self, func, *args, **kwargs): + """Execute a write operation through the queue.""" + if ClickzettaVector._write_queue is None: + raise RuntimeError("Write queue not initialized") + + result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue() + ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) + + # Wait for result + success, result = result_queue.get() + if not success: + raise result + return result + + def get_type(self) -> str: + """Return the vector database type.""" + return "clickzetta" + + def _ensure_connection(self) -> "Connection": + """Ensure connection is available and return it.""" + if self._connection is None: + raise RuntimeError("Database connection not initialized") + return self._connection + + def _table_exists(self) -> bool: + """Check if the table exists.""" + try: + connection = self._ensure_connection() + with connection.cursor() as cursor: + cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") + return True + except (RuntimeError, ValueError) as e: + if "table or view not found" in str(e).lower(): + return False + else: + # Re-raise if it's a different error + raise + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + """Create the collection and add initial documents.""" + # Execute table creation through write queue to avoid concurrent 
conflicts + self._execute_write(self._create_table_and_indexes, embeddings) + + # Add initial texts + if texts: + self.add_texts(texts, embeddings, **kwargs) + + def _create_table_and_indexes(self, embeddings: list[list[float]]): + """Create table and indexes (executed in write worker thread).""" + # Check if table already exists to avoid unnecessary index creation + if self._table_exists(): + logger.info("Table %s.%s already exists, skipping creation", self._config.schema_name, self._table_name) + return + + # Create table with vector and metadata columns + dimension = len(embeddings[0]) if embeddings else 768 + + create_table_sql = f""" + CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} ( + id STRING NOT NULL COMMENT 'Unique document identifier', + {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', + {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', + {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT + 'High-dimensional embedding vector for semantic similarity search', + PRIMARY KEY (id) + ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' + """ + + connection = self._ensure_connection() + with connection.cursor() as cursor: + cursor.execute(create_table_sql) + logger.info("Created table %s.%s", self._config.schema_name, self._table_name) + + # Create vector index + self._create_vector_index(cursor) + + # Create inverted index for full-text search if enabled + if self._config.enable_inverted_index: + self._create_inverted_index(cursor) + + def _create_vector_index(self, cursor): + """Create HNSW vector index for similarity search.""" + # Use a fixed index name based on table and column name + index_name = f"idx_{self._table_name}_vector" + + # First check if an index already exists on this column + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + # Check if vector index already exists on the embedding column + if Field.VECTOR.value in str(idx).lower(): + logger.info("Vector index already exists on column %s", Field.VECTOR.value) + return + except (RuntimeError, ValueError) as e: + logger.warning("Failed to check existing indexes: %s", e) + + index_sql = f""" + CREATE VECTOR INDEX IF NOT EXISTS {index_name} + ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) + PROPERTIES ( + "distance.function" = "{self._config.vector_distance_function}", + "scalar.type" = "f32", + "m" = "16", + "ef.construction" = "128" + ) + """ + try: + cursor.execute(index_sql) + logger.info("Created vector index: %s", index_name) + except (RuntimeError, ValueError) as e: + error_msg = str(e).lower() + if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg: + logger.info("Vector index already exists: %s", e) + else: + logger.exception("Failed to create vector index") + raise + + def _create_inverted_index(self, cursor): + """Create inverted index for full-text search.""" + # Use a fixed index name based on table name to avoid duplicates + index_name = f"idx_{self._table_name}_text" + + # Check if an inverted index already exists on this column + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + idx_str = str(idx).lower() + # More precise 
check: look for inverted index specifically on the content column + if ( + "inverted" in idx_str + and Field.CONTENT_KEY.value.lower() in idx_str + and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str) + ): + logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) + return + except (RuntimeError, ValueError) as e: + logger.warning("Failed to check existing indexes: %s", e) + + index_sql = f""" + CREATE INVERTED INDEX IF NOT EXISTS {index_name} + ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) + PROPERTIES ( + "analyzer" = "{self._config.analyzer_type}", + "mode" = "{self._config.analyzer_mode}" + ) + """ + try: + cursor.execute(index_sql) + logger.info("Created inverted index: %s", index_name) + except (RuntimeError, ValueError) as e: + error_msg = str(e).lower() + # Handle ClickZetta specific error messages + if ( + "already exists" in error_msg + or "already has index" in error_msg + or "with the same type" in error_msg + or "cannot create inverted index" in error_msg + ) and "already has index" in error_msg: + logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) + # Try to get the existing index name for logging + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower(): + logger.info("Found existing inverted index: %s", idx) + break + except (RuntimeError, ValueError): + pass + else: + logger.warning("Failed to create inverted index: %s", e) + # Continue without inverted index - full-text search will fall back to LIKE + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + """Add documents with embeddings to the collection.""" + if not documents: + return + + batch_size = self._config.batch_size + total_batches = (len(documents) + batch_size - 1) // batch_size + + for i in range(0, len(documents), batch_size): + batch_docs = documents[i : i + batch_size] + batch_embeddings = embeddings[i : i + batch_size] + + # Execute batch insert through write queue + self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + + def _insert_batch( + self, + batch_docs: list[Document], + batch_embeddings: list[list[float]], + batch_index: int, + batch_size: int, + total_batches: int, + ): + """Insert a batch of documents using parameterized queries (executed in write worker thread).""" + if not batch_docs or not batch_embeddings: + logger.warning("Empty batch provided, skipping insertion") + return + + if len(batch_docs) != len(batch_embeddings): + logger.error("Mismatch between docs (%d) and embeddings (%d)", len(batch_docs), len(batch_embeddings)) + return + + # Prepare data for parameterized insertion + data_rows = [] + vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 + + for doc, embedding in zip(batch_docs, batch_embeddings): + # Optimized: minimal checks for common case, fallback for edge cases + metadata = doc.metadata if doc.metadata else {} + + if not isinstance(metadata, dict): + metadata = {} + + doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + + # Fast path for JSON serialization + try: + metadata_json = json.dumps(metadata, ensure_ascii=True) + except (TypeError, ValueError): + logger.warning("JSON serialization failed, 
using empty dict") + metadata_json = "{}" + + content = doc.page_content or "" + + # According to ClickZetta docs, vector should be formatted as array string + # for external systems: '[1.0, 2.0, 3.0]' + vector_str = "[" + ",".join(map(str, embedding)) + "]" + data_rows.append([doc_id, content, metadata_json, vector_str]) + + # Check if we have any valid data to insert + if not data_rows: + logger.warning("No valid documents to insert in batch %d/%d", batch_index // batch_size + 1, total_batches) + return + + # Use parameterized INSERT with executemany for better performance and security + # Cast JSON and VECTOR in SQL, pass raw data as parameters + columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" + insert_sql = ( + f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) " + f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" + ) + + connection = self._ensure_connection() + with connection.cursor() as cursor: + try: + # Set session-level hints for batch insert operations + # Note: executemany doesn't support hints parameter, so we set them as session variables + cursor.execute("SET cz.sql.job.fast.mode = true") + cursor.execute("SET cz.sql.compaction.after.commit = true") + cursor.execute("SET cz.storage.always.prefetch.internal = true") + + cursor.executemany(insert_sql, data_rows) + logger.info( + "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", + batch_index // batch_size + 1, + total_batches, + len(data_rows), + vector_dimension, + ) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) + logger.exception("SQL template: %s", insert_sql) + logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") + raise + + def text_exists(self, id: str) -> bool: + """Check if a document exists by ID.""" + safe_id = self._safe_doc_id(id) + connection = self._ensure_connection() + with connection.cursor() as cursor: + cursor.execute( + f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id] + ) + result = cursor.fetchone() + return result[0] > 0 if result else False + + def delete_by_ids(self, ids: list[str]) -> None: + """Delete documents by IDs.""" + if not ids: + return + + # Check if table exists before attempting delete + if not self._table_exists(): + logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name) + return + + # Execute delete through write queue + self._execute_write(self._delete_by_ids_impl, ids) + + def _delete_by_ids_impl(self, ids: list[str]) -> None: + """Implementation of delete by IDs (executed in write worker thread).""" + safe_ids = [self._safe_doc_id(id) for id in ids] + # Create properly escaped string literals for SQL + id_list = ",".join(f"'{id}'" for id in safe_ids) + sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" + + connection = self._ensure_connection() + with connection.cursor() as cursor: + cursor.execute(sql) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """Delete documents by metadata field.""" + # Check if table exists before attempting delete + if not self._table_exists(): + logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name) + return + + # Execute delete through write queue + 
self._execute_write(self._delete_by_metadata_field_impl, key, value) + + def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: + """Implementation of delete by metadata field (executed in write worker thread).""" + connection = self._ensure_connection() + with connection.cursor() as cursor: + # Using JSON path to filter with parameterized query + # Note: JSON path requires literal key name, cannot be parameterized + # Use json_extract_string function for ClickZetta compatibility + sql = ( + f"DELETE FROM {self._config.schema_name}.{self._table_name} " + f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" + ) + cursor.execute(sql, [value]) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Search for documents by vector similarity.""" + top_k = kwargs.get("top_k", 10) + score_threshold = kwargs.get("score_threshold", 0.0) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Add distance threshold based on distance function + vector_dimension = len(query_vector) + if self._config.vector_distance_function == "cosine_distance": + # For cosine distance, smaller is better (0 = identical, 2 = opposite) + distance_func = "COSINE_DISTANCE" + if score_threshold > 0: + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + filter_clauses.append( + f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" + ) + else: + # For L2 distance, smaller is better + distance_func = "L2_DISTANCE" + if score_threshold > 0: + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") + + where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" + + # Execute vector search query + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, + {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + ORDER BY distance + LIMIT {top_k} + """ + + documents = [] + connection = self._ensure_connection() + with connection.cursor() as cursor: + # Use hints parameter for vector search optimization + search_hints = { + "hints": { + "sdk.job.timeout": 60, # Increase timeout for vector search + "cz.sql.job.fast.mode": True, + "cz.storage.parquet.vector.index.read.memory.cache": True, + } + } + cursor.execute(search_sql, parameters=search_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata from JSON string (may be double-encoded) + try: + if row[2]: + metadata = json.loads(row[2]) + + # If result is a string, it's double-encoded JSON - parse again + if isinstance(metadata, str): + metadata = 
json.loads(metadata) + + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError) as e: + logger.exception("JSON parsing failed") + # Fallback: extract document_id with regex + import re + + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + # Ensure required fields are set + metadata["doc_id"] = row[0] # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row[0] # fallback to segment id + + # Add score based on distance + if self._config.vector_distance_function == "cosine_distance": + metadata["score"] = 1 - (row[3] / 2) + else: + metadata["score"] = 1 / (1 + row[3]) + + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + + return documents + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Search for documents using full-text search with inverted index.""" + if not self._config.enable_inverted_index: + logger.warning("Full-text search is not enabled. Enable inverted index in config.") + return [] + + top_k = kwargs.get("top_k", 10) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Use match_all function for full-text search + # match_all requires all terms to be present + # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause + escaped_query = query.replace("'", "''") + filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')") + + where_clause = " AND ".join(filter_clauses) + + # Execute full-text search query + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + LIMIT {top_k} + """ + + documents = [] + connection = self._ensure_connection() + with connection.cursor() as cursor: + try: + # Use hints parameter for full-text search optimization + fulltext_hints = { + "hints": { + "sdk.job.timeout": 30, # Timeout for full-text search + "cz.sql.job.fast.mode": True, + "cz.sql.index.prewhere.enabled": True, + } + } + cursor.execute(search_sql, parameters=fulltext_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata from JSON string (may be double-encoded) + try: + if row[2]: + metadata = json.loads(row[2]) + + # If result is a string, it's double-encoded JSON - parse again + if isinstance(metadata, str): + metadata = json.loads(metadata) + + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError) as e: + logger.exception("JSON parsing failed") + # Fallback: extract document_id with regex + import re + + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + 
# Ensure required fields are set + metadata["doc_id"] = row[0] # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row[0] # fallback to segment id + + # Add a relevance score for full-text search + metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Full-text search failed") + # Fallback to LIKE search if full-text search fails + return self._search_by_like(query, **kwargs) + + return documents + + def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]: + """Fallback search using LIKE operator.""" + top_k = kwargs.get("top_k", 10) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Use simple quote escaping for LIKE clause + escaped_query = query.replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'") + where_clause = " AND ".join(filter_clauses) + + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + LIMIT {top_k} + """ + + documents = [] + connection = self._ensure_connection() + with connection.cursor() as cursor: + # Use hints parameter for LIKE search optimization + like_hints = { + "hints": { + "sdk.job.timeout": 20, # Timeout for LIKE search + "cz.sql.job.fast.mode": True, + } + } + cursor.execute(search_sql, parameters=like_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata from JSON string (may be double-encoded) + try: + if row[2]: + metadata = json.loads(row[2]) + + # If result is a string, it's double-encoded JSON - parse again + if isinstance(metadata, str): + metadata = json.loads(metadata) + + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError) as e: + logger.exception("JSON parsing failed") + # Fallback: extract document_id with regex + import re + + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + # Ensure required fields are set + metadata["doc_id"] = row[0] # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row[0] # fallback to segment id + + metadata["score"] = 0.5 # Lower score for LIKE search + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + + return documents + + def delete(self) -> None: + """Delete the entire collection.""" + connection = self._ensure_connection() + with connection.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") + + def 
_format_vector_simple(self, vector: list[float]) -> str: + """Simple vector formatting for SQL queries.""" + return ",".join(map(str, vector)) + + def _safe_doc_id(self, doc_id: str) -> str: + """Ensure doc_id is safe for SQL and doesn't contain special characters.""" + if not doc_id: + return str(uuid.uuid4()) + # Remove or replace potentially problematic characters + safe_id = str(doc_id) + # Only allow alphanumeric, hyphens, underscores + safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_") + if not safe_id: # If all characters were removed + return str(uuid.uuid4()) + return safe_id[:255] # Limit length + + +class ClickzettaVectorFactory(AbstractVectorFactory): + """Factory for creating Clickzetta vector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: + """Initialize a Clickzetta vector instance.""" + # Get configuration from environment variables or dataset config + config = ClickzettaConfig( + username=dify_config.CLICKZETTA_USERNAME or "", + password=dify_config.CLICKZETTA_PASSWORD or "", + instance=dify_config.CLICKZETTA_INSTANCE or "", + service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com", + workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start", + vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap", + schema_name=dify_config.CLICKZETTA_SCHEMA or "dify", + batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100, + enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True, + analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese", + analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart", + vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance", + ) + + # Use dataset collection name as table name + collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower() + + return ClickzettaVector(collection_name=collection_name, config=config) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 9dea050dc3..49c4b392fe 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse import requests from elasticsearch import Elasticsearch from flask import current_app +from packaging.version import parse as parse_version from pydantic import BaseModel, model_validator from core.rag.datasource.vdb.field import Field @@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector): return cast(str, info["version"]["number"]) def _check_version(self): - if self._version < "8.0.0": + if parse_version(self._version) < parse_version("8.0.0"): raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 784e27fc7f..91d667ff2c 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -1,5 +1,6 @@ import json import logging +import math from typing import Any, Optional import tablestore # type: ignore @@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel): access_key_secret: Optional[str] = None instance_name: Optional[str] = None endpoint: Optional[str] = None + normalize_full_text_bm25_score: Optional[bool] = False 
@model_validator(mode="before") @classmethod @@ -47,6 +49,7 @@ class TableStoreVector(BaseVector): config.access_key_secret, config.instance_name, ) + self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score self._table_name = f"{collection_name}" self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" @@ -131,8 +134,8 @@ class TableStoreVector(BaseVector): filtered_list = None if document_ids_filter: filtered_list = ["document_id=" + item for item in document_ids_filter] - - return self._search_by_full_text(query, filtered_list, top_k) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._search_by_full_text(query, filtered_list, top_k, score_threshold) def delete(self) -> None: self._delete_table_if_exist() @@ -318,7 +321,19 @@ class TableStoreVector(BaseVector): documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: + @staticmethod + def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float: + """ + Args: + score: BM25 search score. + k: decay factor, the larger the k, the steeper the low score end + """ + normalized_score = 1 - math.exp(-k * score) + return max(0.0, min(1.0, normalized_score)) + + def _search_by_full_text( + self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float + ) -> list[Document]: bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) @@ -339,15 +354,27 @@ class TableStoreVector(BaseVector): documents = [] for search_hit in search_response.search_hits: + score = None + if self._normalize_full_text_bm25_score: + score = self._normalize_score_exp_decay(search_hit.score) + + # skip when score is below threshold and use normalize score + if score and score <= score_threshold: + continue + ots_column_map = {} for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - vector_str = ots_column_map.get(Field.VECTOR.value) metadata_str = ots_column_map.get(Field.METADATA_KEY.value) - vector = json.loads(vector_str) if vector_str else None metadata = json.loads(metadata_str) if metadata_str else {} + vector_str = ots_column_map.get(Field.VECTOR.value) + vector = json.loads(vector_str) if vector_str else None + + if score: + metadata["score"] = score + documents.append( Document( page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", @@ -355,6 +382,8 @@ class TableStoreVector(BaseVector): metadata=metadata, ) ) + if self._normalize_full_text_bm25_score: + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents @@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory): instance_name=dify_config.TABLESTORE_INSTANCE_NAME, access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID, access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET, + normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE, ), ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3aa4b67a78..0517d5a6d1 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -246,6 +246,10 @@ class 
TencentVector(BaseVector): return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + document_ids_filter = kwargs.get("document_ids_filter") + filter = None + if document_ids_filter: + filter = Filter(Filter.In("metadata.document_id", document_ids_filter)) if not self._enable_hybrid_search: return [] res = self._client.hybrid_search( @@ -269,6 +273,7 @@ class TencentVector(BaseVector): ), retrieve_vector=False, limit=kwargs.get("top_k", 4), + filter=filter, ) score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 43c49ed4b3..eef03ce412 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -172,6 +172,10 @@ class Vector: from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory return MatrixoneVectorFactory + case VectorType.CLICKZETTA: + from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory + + return ClickzettaVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 0d70947b72..a415142196 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -30,3 +30,4 @@ class VectorType(StrEnum): TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" + CLICKZETTA = "clickzetta" diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index 6ef932ad22..1f054bccdb 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 875626eb34..17f4d1af2d 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,5 +1,6 @@ import json import logging +import operator from typing import Any, Optional, cast import requests @@ -130,13 +131,15 @@ class NotionExtractor(BaseExtractor): data[property_name] = value row_dict = {k: v for k, v in data.items() if v} row_content = "" - for key, value in row_dict.items(): + for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) row_content = row_content + f"{key}:{value_content}\n" else: row_content = row_content + f"{key}:{value}\n" + if "url" in result: + row_content = row_content + f"Row Page URL:{result.get('url', '')}\n" database_content.append(row_content) has_more = response_data.get("has_more", False) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 14363de7d4..0eff7c186a 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, "storage") + content = self.parse_docx(self.file_path) return [ Document( 
page_content=content, @@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor): paragraph_content.append(run.text) return "".join(paragraph_content).strip() - def _parse_paragraph(self, paragraph, image_map): - paragraph_content = [] - for run in paragraph.runs: - if run.element.xpath(".//a:blip"): - for blip in run.element.xpath(".//a:blip"): - embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if embed_id: - rel_target = run.part.rels[embed_id].target_ref - if rel_target in image_map: - paragraph_content.append(image_map[rel_target]) - if run.text.strip(): - paragraph_content.append(run.text.strip()) - return " ".join(paragraph_content) if paragraph_content else "" - - def parse_docx(self, docx_path, image_folder): + def parse_docx(self, docx_path): doc = DocxDocument(docx_path) - os.makedirs(image_folder, exist_ok=True) content = [] diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index bcaf299892..d654463be9 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -5,14 +5,13 @@ from __future__ import annotations from typing import Any, Optional from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, Literal, RecursiveCharacterTextSplitter, Set, - TokenTextSplitter, Union, ) @@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] - if issubclass(cls, TokenTextSplitter): - extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", - "allowed_special": allowed_special, - "disallowed_special": disallowed_special, - } - kwargs = {**kwargs, **extra_kwargs} - return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 35e16b5c8f..d6961cdaa4 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -20,9 +20,6 @@ class Tool(ABC): The base class of a tool """ - entity: ToolEntity - runtime: ToolRuntime - def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: self.entity = entity self.runtime = runtime diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 1639dd687f..a8fd6ec2cd 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool): @staticmethod def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: try: - if local_tz is None: - local_tz = datetime.now().astimezone().tzinfo - if isinstance(local_tz, str): - local_tz = pytz.timezone(local_tz) local_time = datetime.strptime(localtime, time_format) - localtime = local_tz.localize(local_time) # type: ignore + if local_tz is None: + localtime = local_time.astimezone() # type: ignore + elif isinstance(local_tz, str): + local_tz = pytz.timezone(local_tz) + localtime = local_tz.localize(local_time) # type: ignore timestamp = int(localtime.timestamp()) # type: ignore return timestamp except Exception 
as e: diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index f9b776b3b9..91316b859a 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool): target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore if not target_time: yield self.create_text_message( - f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" + f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}" ) return diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 724a2291c6..84efefba07 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -20,8 +20,6 @@ class BuiltinTool(Tool): :param meta: the meta data of a tool call processing """ - provider: str - def __init__(self, provider: str, **kwargs): super().__init__(**kwargs) self.provider = provider diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 10653b9948..e112de9578 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -1,7 +1,8 @@ import json from collections.abc import Generator +from dataclasses import dataclass from os import getenv -from typing import Any, Optional +from typing import Any, Optional, Union from urllib.parse import urlencode import httpx @@ -20,10 +21,21 @@ API_TOOL_DEFAULT_TIMEOUT = ( ) -class ApiTool(Tool): - api_bundle: ApiToolBundle - provider_id: str +@dataclass +class ParsedResponse: + """Represents a parsed HTTP response with type information""" + content: Union[str, dict] + is_json: bool + + def to_string(self) -> str: + """Convert response to string format for credential validation""" + if isinstance(self.content, dict): + return json.dumps(self.content, ensure_ascii=False) + return str(self.content) + + +class ApiTool(Tool): """ Api tool """ @@ -61,7 +73,9 @@ class ApiTool(Tool): response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response - return self.validate_and_parse_response(response) + parsed_response = self.validate_and_parse_response(response) + # For credential validation, always return as string + return parsed_response.to_string() def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.API @@ -115,23 +129,36 @@ class ApiTool(Tool): return headers - def validate_and_parse_response(self, response: httpx.Response) -> str: + def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse: """ - validate the response + validate the response and return parsed content with type information + + :return: ParsedResponse with content and is_json flag """ if isinstance(response, httpx.Response): if response.status_code >= 400: raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") if not response.content: - return "Empty response from the tool, please check your parameters and try again." 
+ return ParsedResponse( + "Empty response from the tool, please check your parameters and try again.", False + ) + + # Check content type + content_type = response.headers.get("content-type", "").lower() + is_json_content_type = "application/json" in content_type + + # Try to parse as JSON try: - response = response.json() - try: - return json.dumps(response, ensure_ascii=False) - except Exception: - return json.dumps(response) + json_response = response.json() + # If content-type indicates JSON, return as JSON object + if is_json_content_type: + return ParsedResponse(json_response, True) + else: + # If content-type doesn't indicate JSON, treat as text regardless of content + return ParsedResponse(response.text, False) except Exception: - return response.text + # Not valid JSON, return as text + return ParsedResponse(response.text, False) else: raise ValueError(f"Invalid response type {type(response)}") @@ -372,7 +399,14 @@ class ApiTool(Tool): response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters) # validate response - response = self.validate_and_parse_response(response) + parsed_response = self.validate_and_parse_response(response) - # assemble invoke message - yield self.create_text_message(response) + # assemble invoke message based on response type + if parsed_response.is_json and isinstance(parsed_response.content, dict): + yield self.create_json_message(parsed_response.content) + else: + # Convert to string if needed and create text message + text_response = ( + parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content) + ) + yield self.create_text_message(text_response) diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index d1bacbc735..8ebbb6b0fe 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient from core.mcp.types import ImageContent, TextContent from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType class MCPTool(Tool): - tenant_id: str - icon: str - runtime_parameters: Optional[list[ToolParameter]] - server_url: str - provider_id: str - def __init__( self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str ) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id self.icon = icon - self.runtime_parameters = None self.server_url = server_url self.provider_id = provider_id diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index aef2677c36..db38c10e81 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too class PluginTool(Tool): - tenant_id: str - icon: str - plugin_unique_identifier: str - runtime_parameters: Optional[list[ToolParameter]] - def __init__( self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str ) -> None: @@ -21,7 +16,7 @@ class PluginTool(Tool): self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters = None + self.runtime_parameters: 
Optional[list[ToolParameter]] = None def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 178f2b9689..83444c02d8 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -29,7 +29,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatorUserRole @@ -247,7 +247,8 @@ class ToolEngine: ) elif response.type == ToolInvokeMessage.MessageType.JSON: result += json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object), + ensure_ascii=False, ) else: result += str(response.message) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 1bb4cfa4cd..2737bcfb16 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -7,6 +7,7 @@ from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +import sqlalchemy as sa from pydantic import TypeAdapter from yarl import URL @@ -616,7 +617,7 @@ class ToolManager: WHERE tenant_id = :tenant_id ORDER BY tenant_id, provider, is_default DESC, created_at DESC """ - ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() @classmethod diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index ec0575f6c3..d58807e29f 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas class DatasetRetrieverTool(Tool): - retrieval_tool: DatasetRetrieverBaseTool - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: super().__init__(entity, runtime) self.retrieval_tool = retrieval_tool diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 9998de0465..ac12d83ef2 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,14 @@ import logging from collections.abc import Generator +from datetime import date, datetime +from decimal import Decimal from mimetypes import guess_extension -from typing import Optional +from typing import Optional, cast +from uuid import UUID + +import numpy as np +import pytz +from flask_login import current_user from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage @@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) +def safe_json_value(v): + if isinstance(v, datetime): + tz_name = getattr(current_user, "timezone", None) if current_user is not None else None + if not tz_name: + tz_name = "UTC" + return 
v.astimezone(pytz.timezone(tz_name)).isoformat() + elif isinstance(v, date): + return v.isoformat() + elif isinstance(v, UUID): + return str(v) + elif isinstance(v, Decimal): + return float(v) + elif isinstance(v, bytes): + try: + return v.decode("utf-8") + except UnicodeDecodeError: + return v.hex() + elif isinstance(v, memoryview): + return v.tobytes().hex() + elif isinstance(v, np.ndarray): + return v.tolist() + elif isinstance(v, dict): + return safe_json_dict(v) + elif isinstance(v, list | tuple | set): + return [safe_json_value(i) for i in v] + else: + return v + + +def safe_json_dict(d): + if not isinstance(d, dict): + raise TypeError("safe_json_dict() expects a dictionary (dict) as input") + return {k: safe_json_value(v) for k, v in d.items()} + + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages( @@ -113,6 +155,12 @@ class ToolFileMessageTransformer: ) else: yield message + + elif message.type == ToolInvokeMessage.MessageType.JSON: + if isinstance(message.message, ToolInvokeMessage.JsonMessage): + json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) + json_msg.json_object = safe_json_value(json_msg.json_object) + yield message else: yield message diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index db6b84082f..6824e5e0e8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -25,15 +25,6 @@ logger = logging.getLogger(__name__) class WorkflowTool(Tool): - workflow_app_id: str - version: str - workflow_entities: dict[str, Any] - workflow_call_depth: int - thread_pool_id: Optional[str] = None - workflow_as_tool_id: str - - label: str - """ Workflow tool. """ diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 13274f4e0e..a99f5eece3 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -119,6 +119,13 @@ class ObjectSegment(Segment): class ArraySegment(Segment): + @property + def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" + return super().text + @property def markdown(self) -> str: items = [] @@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment): @property def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" return json.dumps(self.value, ensure_ascii=False) diff --git a/api/core/variables/types.py b/api/core/variables/types.py index e79b2410bf..d28fb11401 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -109,7 +109,7 @@ class SegmentType(StrEnum): elif array_validation == ArrayValidation.FIRST: return element_type.is_valid(value[0]) else: - return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value) + return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: """ @@ -152,7 +152,7 @@ class SegmentType(StrEnum): _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { - # ARRAY_ANY does not have correpond element type. + # ARRAY_ANY does not have corresponding element type. 
SegmentType.ARRAY_STRING: SegmentType.STRING, SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index f3061f7d96..23512c8ce4 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -597,7 +597,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str: for i in range(1, len(raw_results)): spk, txt = raw_results[i] - if spk == None: + if spk is None: merged_results.append((None, current_text)) continue diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index fe103c7117..e45f63bbec 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -91,7 +91,7 @@ class Executor: self.auth = node_data.authorization self.timeout = timeout self.ssl_verify = node_data.ssl_verify - self.params = [] + self.params = None self.headers = {} self.content = None self.files = None @@ -139,7 +139,8 @@ class Executor: (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) - self.params = result + if result: + self.params = result def _init_headers(self): """ @@ -277,6 +278,22 @@ class Executor: elif self.auth.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key or "" + # Handle Content-Type for multipart/form-data requests + # Fix for issue #22880: Missing boundary when using multipart/form-data + body = self.node_data.body + if body and body.type == "form-data": + # For multipart/form-data with files, let httpx handle the boundary automatically + # by not setting Content-Type header when files are present + if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files): + # Only set Content-Type when there are no actual files + # This ensures httpx generates the correct boundary + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = "multipart/form-data" + elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: + # Set Content-Type for other body types + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + return headers def _validate_and_parse_response(self, response: httpx.Response) -> Response: @@ -384,15 +401,24 @@ class Executor: # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. # This prevents logging meaningless placeholder entries. 
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for key, (filename, content, mime_type) in self.files: + for file_entry in self.files: + # file_entry should be (key, (filename, content, mime_type)), but handle edge cases + if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2: + continue # skip malformed entries + key = file_entry[0] + content = file_entry[1][1] body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content - try: - body_string += content.decode("utf-8") - except UnicodeDecodeError: - # fix: decode binary content - pass + # decode content safely + if isinstance(content, bytes): + try: + body_string += content.decode("utf-8") + except UnicodeDecodeError: + body_string += content.decode("utf-8", errors="replace") + elif isinstance(content, str): + body_string += content + else: + body_string += f"[Unsupported content type: {type(content).__name__}]" body_string += "\r\n" body_string += f"--{boundary}--\r\n" elif self.node_data.body: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index f1767bdf9e..b71271abeb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e041e217ca..7303b68501 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode): **{key: metadata_name, key_value: f"%{value}"} ) ) + case "in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) + case "not in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) case "=" | "is": if isinstance(value, str): filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 90a0397b67..dfc2a0000b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -3,7 +3,7 @@ import io import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from 
core.model_runtime.entities.model_entities import ( - AIModelEntity, ModelFeature, ModelPropertyKey, ModelType, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -1006,21 +1004,6 @@ class LLMNode(BaseNode): ) return saved_file - def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: - """ - Fetch model schema - """ - model_name = self._node_data.model.name - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name - ) - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_credentials = model_instance.credentials - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - return model_schema - @staticmethod def fetch_structured_output_schema( *, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 4c8e13de70..df89b2476d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -318,6 +318,33 @@ class ToolNode(BaseNode): json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"Tool file {tool_file_id} does not exist") + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": message.message.text, + } + + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + stream_text = f"Link: {message.message.text}\n" text += stream_text yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b027a165f9..a8f025a750 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -136,6 +136,8 @@ def init_app(app: DifyApp): from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.flask import FlaskInstrumentor + from opentelemetry.instrumentation.redis import RedisInstrumentor + from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider from opentelemetry.propagate import set_global_textmap @@ -234,6 +236,8 @@ def init_app(app: DifyApp): CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() instrument_exception_logging() 
init_sqlalchemy_instrumentor(app) + RedisInstrumentor().instrument() + RequestsInstrumentor().instrument() atexit.register(shutdown_tracer) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index bd35278544..d13393dd14 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -69,6 +69,19 @@ class Storage: from extensions.storage.supabase_storage import SupabaseStorage return SupabaseStorage + case StorageType.CLICKZETTA_VOLUME: + from extensions.storage.clickzetta_volume.clickzetta_volume_storage import ( + ClickZettaVolumeConfig, + ClickZettaVolumeStorage, + ) + + def create_clickzetta_volume_storage(): + # ClickZettaVolumeConfig will automatically read from environment variables + # and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set + volume_config = ClickZettaVolumeConfig() + return ClickZettaVolumeStorage(volume_config) + + return create_clickzetta_volume_storage case _: raise ValueError(f"unsupported storage type {storage_type}") diff --git a/api/extensions/storage/clickzetta_volume/__init__.py b/api/extensions/storage/clickzetta_volume/__init__.py new file mode 100644 index 0000000000..8a1588034b --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/__init__.py @@ -0,0 +1,5 @@ +"""ClickZetta Volume storage implementation.""" + +from .clickzetta_volume_storage import ClickZettaVolumeStorage + +__all__ = ["ClickZettaVolumeStorage"] diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py new file mode 100644 index 0000000000..09ab37f42e --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -0,0 +1,530 @@ +"""ClickZetta Volume Storage Implementation + +This module provides storage backend using ClickZetta Volume functionality. +Supports Table Volume, User Volume, and External Volume types. +""" + +import logging +import os +import tempfile +from collections.abc import Generator +from io import BytesIO +from pathlib import Path +from typing import Optional + +import clickzetta # type: ignore[import] +from pydantic import BaseModel, model_validator + +from extensions.storage.base_storage import BaseStorage + +from .volume_permissions import VolumePermissionManager, check_volume_permission + +logger = logging.getLogger(__name__) + + +class ClickZettaVolumeConfig(BaseModel): + """Configuration for ClickZetta Volume storage.""" + + username: str = "" + password: str = "" + instance: str = "" + service: str = "api.clickzetta.com" + workspace: str = "quick_start" + vcluster: str = "default_ap" + schema_name: str = "dify" + volume_type: str = "table" # table|user|external + volume_name: Optional[str] = None # For external volumes + table_prefix: str = "dataset_" # Prefix for table volume names + dify_prefix: str = "dify_km" # Directory prefix for User Volume + permission_check: bool = True # Enable/disable permission checking + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validate the configuration values. + + This method will first try to use CLICKZETTA_VOLUME_* environment variables, + then fall back to CLICKZETTA_* environment variables (for vector DB config). 
+        """
+        import os
+
+        # Helper function to get environment variable with fallback
+        def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
+            # First try CLICKZETTA_VOLUME_* specific config
+            volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
+            if volume_value:
+                return str(volume_value)
+
+            # Then try environment variables
+            volume_env = os.getenv(volume_key)
+            if volume_env:
+                return volume_env
+
+            # Fall back to existing CLICKZETTA_* config
+            fallback_env = os.getenv(fallback_key)
+            if fallback_env:
+                return fallback_env
+
+            return default or ""
+
+        # Apply environment variables with fallback to existing CLICKZETTA_* config
+        values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
+        values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
+        values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
+        values.setdefault(
+            "service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
+        )
+        values.setdefault(
+            "workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
+        )
+        values.setdefault(
+            "vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
+        )
+        values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
+
+        # Volume-specific configurations (no fallback to vector DB config)
+        values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
+        values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
+        values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
+        values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
+        # Permission checking is temporarily disabled; set it to False directly
+        values.setdefault("permission_check", False)
+
+        # Validate required fields
+        if not values.get("username"):
+            raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
+        if not values.get("password"):
+            raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
+        if not values.get("instance"):
+            raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
+
+        # Validate volume type
+        volume_type = values["volume_type"]
+        if volume_type not in ["table", "user", "external"]:
+            raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
+
+        if volume_type == "external" and not values.get("volume_name"):
+            raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
+
+        return values
+
+
+class ClickZettaVolumeStorage(BaseStorage):
+    """ClickZetta Volume storage implementation."""
+
+    def __init__(self, config: ClickZettaVolumeConfig):
+        """Initialize ClickZetta Volume storage.
+ + Args: + config: ClickZetta Volume configuration + """ + self._config = config + self._connection = None + self._permission_manager: VolumePermissionManager | None = None + self._init_connection() + self._init_permission_manager() + + logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type) + + def _init_connection(self): + """Initialize ClickZetta connection.""" + try: + self._connection = clickzetta.connect( + username=self._config.username, + password=self._config.password, + instance=self._config.instance, + service=self._config.service, + workspace=self._config.workspace, + vcluster=self._config.vcluster, + schema=self._config.schema_name, + ) + logger.debug("ClickZetta connection established") + except Exception as e: + logger.exception("Failed to connect to ClickZetta") + raise + + def _init_permission_manager(self): + """Initialize permission manager.""" + try: + self._permission_manager = VolumePermissionManager( + self._connection, self._config.volume_type, self._config.volume_name + ) + logger.debug("Permission manager initialized") + except Exception as e: + logger.exception("Failed to initialize permission manager") + raise + + def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str: + """Get the appropriate volume path based on volume type.""" + if self._config.volume_type == "user": + # Add dify prefix for User Volume to organize files + return f"{self._config.dify_prefix}/{filename}" + elif self._config.volume_type == "table": + # Check if this should use User Volume (special directories) + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + # Use User Volume with dify prefix for special directories + return f"{self._config.dify_prefix}/{filename}" + + if dataset_id: + return f"{self._config.table_prefix}{dataset_id}/{filename}" + else: + # Extract dataset_id from filename if not provided + # Format: dataset_id/filename + if "/" in filename: + return filename + else: + raise ValueError("dataset_id is required for table volume or filename must include dataset_id/") + elif self._config.volume_type == "external": + return filename + else: + raise ValueError(f"Unsupported volume type: {self._config.volume_type}") + + def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str: + """Get SQL prefix for volume operations.""" + if self._config.volume_type == "user": + return "USER VOLUME" + elif self._config.volume_type == "table": + # For Dify's current file storage pattern, most files are stored in + # paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext" + # These should use USER VOLUME for better compatibility + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + return "USER VOLUME" + + # Only use TABLE VOLUME for actual dataset-specific paths + # like "dataset_12345/file.pdf" or paths with dataset_ prefix + if dataset_id: + table_name = f"{self._config.table_prefix}{dataset_id}" + else: + # Default table name for generic operations + table_name = "default_dataset" + return f"TABLE VOLUME {table_name}" + elif self._config.volume_type == "external": + return f"VOLUME {self._config.volume_name}" + else: + raise ValueError(f"Unsupported volume type: {self._config.volume_type}") + + def _execute_sql(self, sql: str, fetch: bool = False): + """Execute SQL command.""" + try: + if self._connection is None: + raise RuntimeError("Connection not initialized") + with self._connection.cursor() as cursor: + 
cursor.execute(sql) + if fetch: + return cursor.fetchall() + return None + except Exception as e: + logger.exception("SQL execution failed: %s", sql) + raise + + def _ensure_table_volume_exists(self, dataset_id: str) -> None: + """Ensure table volume exists for the given dataset_id.""" + if self._config.volume_type != "table" or not dataset_id: + return + + # Skip for upload_files and other special directories that use USER VOLUME + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + return + + table_name = f"{self._config.table_prefix}{dataset_id}" + + try: + # Check if table exists + check_sql = f"SHOW TABLES LIKE '{table_name}'" + result = self._execute_sql(check_sql, fetch=True) + + if not result: + # Create table with volume + create_sql = f""" + CREATE TABLE {table_name} ( + id INT PRIMARY KEY AUTO_INCREMENT, + filename VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + INDEX idx_filename (filename) + ) WITH VOLUME + """ + self._execute_sql(create_sql) + logger.info("Created table volume: %s", table_name) + + except Exception as e: + logger.warning("Failed to create table volume %s: %s", table_name, e) + # Don't raise exception, let the operation continue + # The table might exist but not be visible due to permissions + + def save(self, filename: str, data: bytes) -> None: + """Save data to ClickZetta Volume. + + Args: + filename: File path in volume + data: File content as bytes + """ + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + # Ensure table volume exists (for table volumes) + if dataset_id: + self._ensure_table_volume_exists(dataset_id) + + # Check permissions (if enabled) + if self._config.permission_check: + # Skip permission check for special directories that use USER VOLUME + if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + if self._permission_manager is not None: + check_volume_permission(self._permission_manager, "save", dataset_id) + + # Write data to temporary file + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file.write(data) + temp_file_path = temp_file.name + + try: + # Upload to volume + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'" + else: + sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'" + + self._execute_sql(sql) + logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path) + finally: + # Clean up temporary file + Path(temp_file_path).unlink(missing_ok=True) + + def load_once(self, filename: str) -> bytes: + """Load file content from ClickZetta Volume. 
+ + Args: + filename: File path in volume + + Returns: + File content as bytes + """ + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + # Check permissions (if enabled) + if self._config.permission_check: + # Skip permission check for special directories that use USER VOLUME + if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + if self._permission_manager is not None: + check_volume_permission(self._permission_manager, "load_once", dataset_id) + + # Download to temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'" + else: + sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'" + + self._execute_sql(sql) + + # Find the downloaded file (may be in subdirectories) + downloaded_file = None + for root, dirs, files in os.walk(temp_dir): + for file in files: + if file == filename or file == os.path.basename(filename): + downloaded_file = Path(root) / file + break + if downloaded_file: + break + + if not downloaded_file or not downloaded_file.exists(): + raise FileNotFoundError(f"Downloaded file not found: {filename}") + + content = downloaded_file.read_bytes() + + logger.debug("File %s loaded from ClickZetta Volume", filename) + return content + + def load_stream(self, filename: str) -> Generator: + """Load file as stream from ClickZetta Volume. + + Args: + filename: File path in volume + + Yields: + File content chunks + """ + content = self.load_once(filename) + batch_size = 4096 + stream = BytesIO(content) + + while chunk := stream.read(batch_size): + yield chunk + + logger.debug("File %s loaded as stream from ClickZetta Volume", filename) + + def download(self, filename: str, target_filepath: str): + """Download file from ClickZetta Volume to local path. + + Args: + filename: File path in volume + target_filepath: Local target file path + """ + content = self.load_once(filename) + + with Path(target_filepath).open("wb") as f: + f.write(content) + + logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) + + def exists(self, filename: str) -> bool: + """Check if file exists in ClickZetta Volume. 
+ + Args: + filename: File path in volume + + Returns: + True if file exists, False otherwise + """ + try: + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'" + else: + sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'" + + rows = self._execute_sql(sql, fetch=True) + + exists = len(rows) > 0 + logger.debug("File %s exists check: %s", filename, exists) + return exists + except Exception as e: + logger.warning("Error checking file existence for %s: %s", filename, e) + return False + + def delete(self, filename: str): + """Delete file from ClickZetta Volume. + + Args: + filename: File path in volume + """ + if not self.exists(filename): + logger.debug("File %s not found, skip delete", filename) + return + + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"REMOVE {volume_prefix} FILE '{volume_path}'" + else: + sql = f"REMOVE {volume_prefix} FILE '{filename}'" + + self._execute_sql(sql) + + logger.debug("File %s deleted from ClickZetta Volume", filename) + + def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: + """Scan files and directories in ClickZetta Volume. 
+
+        Args:
+            path: Path to scan (dataset_id for table volumes)
+            files: Include files in results
+            directories: Include directories in results
+
+        Returns:
+            List of file/directory paths
+        """
+        try:
+            # For table volumes, path is treated as dataset_id
+            dataset_id = None
+            if self._config.volume_type == "table":
+                dataset_id = path
+                path = ""  # Root of the table volume
+
+            volume_prefix = self._get_volume_sql_prefix(dataset_id)
+
+            # For User Volume, add dify prefix to path
+            if volume_prefix == "USER VOLUME":
+                if path:
+                    scan_path = f"{self._config.dify_prefix}/{path}"
+                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
+                else:
+                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
+            else:
+                if path:
+                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
+                else:
+                    sql = f"LIST {volume_prefix}"
+
+            rows = self._execute_sql(sql, fetch=True)
+
+            result = []
+            for row in rows:
+                file_path = row[0]  # relative_path column
+
+                # For User Volume, remove dify prefix from results
+                dify_prefix_with_slash = f"{self._config.dify_prefix}/"
+                if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
+                    file_path = file_path[len(dify_prefix_with_slash) :]  # Remove prefix
+
+                if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
+                    result.append(file_path)
+
+            logger.debug("Scanned %d items in path %s", len(result), path)
+            return result
+
+        except Exception as e:
+            logger.exception("Error scanning path %s", path)
+            return []
diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
new file mode 100644
index 0000000000..d5d04f121b
--- /dev/null
+++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
@@ -0,0 +1,516 @@
+"""ClickZetta Volume file lifecycle management.
+
+This module provides lifecycle management features such as file versioning, automatic cleanup, backup, and restore.
+It supports full lifecycle management for knowledge base files.
+"""
+
+import json
+import logging
+from dataclasses import asdict, dataclass
+from datetime import datetime, timedelta
+from enum import Enum
+from typing import Any, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class FileStatus(Enum):
+    """File status enumeration."""
+
+    ACTIVE = "active"  # Active
+    ARCHIVED = "archived"  # Archived
+    DELETED = "deleted"  # Deleted (soft delete)
+    BACKUP = "backup"  # Backup file
+
+
+@dataclass
+class FileMetadata:
+    """File metadata."""
+
+    filename: str
+    size: int | None
+    created_at: datetime
+    modified_at: datetime
+    version: int | None
+    status: FileStatus
+    checksum: Optional[str] = None
+    tags: Optional[dict[str, str]] = None
+    parent_version: Optional[int] = None
+
+    def to_dict(self) -> dict:
+        """Convert to a dictionary."""
+        data = asdict(self)
+        data["created_at"] = self.created_at.isoformat()
+        data["modified_at"] = self.modified_at.isoformat()
+        data["status"] = self.status.value
+        return data
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "FileMetadata":
+        """Create an instance from a dictionary."""
+        data = data.copy()
+        data["created_at"] = datetime.fromisoformat(data["created_at"])
+        data["modified_at"] = datetime.fromisoformat(data["modified_at"])
+        data["status"] = FileStatus(data["status"])
+        return cls(**data)
+
+
+class FileLifecycleManager:
+    """File lifecycle manager."""
+
+    def __init__(self, storage, dataset_id: Optional[str] = None):
+        """Initialize the lifecycle manager.
+
+        Args:
+            storage: ClickZetta Volume storage instance
+            dataset_id: Dataset ID (used for Table Volume)
+        """
+        self._storage = storage
+        self._dataset_id = dataset_id
+        self._metadata_file = ".dify_file_metadata.json"
+        self._version_prefix = ".versions/"
+        self._backup_prefix = ".backups/"
+        self._deleted_prefix = ".deleted/"
+
+        # Get the permission manager (if one exists)
+        self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
+
+    def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
+        """Save a file and manage its lifecycle.
+
+        Args:
+            filename: File name
+            data: File content
+            tags: File tags
+
+        Returns:
+            File metadata
+        """
+        # Permission check
+        if not self._check_permission(filename, "save"):
+            from .volume_permissions import VolumePermissionError
+
+            raise VolumePermissionError(
+                f"Permission denied for lifecycle save operation on file: {filename}",
+                operation="save",
+                volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
+                dataset_id=self._dataset_id,
+            )
+
+        try:
+            # 1. Check whether an older version exists
+            metadata_dict = self._load_metadata()
+            current_metadata = metadata_dict.get(filename)
+
+            # 2. If an older version exists, create a version backup
+            if current_metadata:
+                self._create_version_backup(filename, current_metadata)
+
+            # 3. Compute file information
+            now = datetime.now()
+            checksum = self._calculate_checksum(data)
+            new_version = (current_metadata["version"] + 1) if current_metadata else 1
+
+            # 4. Save the new file
+            self._storage.save(filename, data)
+
+            # 5. Create metadata
+            created_at = now
+            parent_version = None
+
+            if current_metadata:
+                # If created_at is a string, convert it to datetime
+                if isinstance(current_metadata["created_at"], str):
+                    created_at = datetime.fromisoformat(current_metadata["created_at"])
+                else:
+                    created_at = current_metadata["created_at"]
+                parent_version = current_metadata["version"]
+
+            file_metadata = FileMetadata(
+                filename=filename,
+                size=len(data),
+                created_at=created_at,
+                modified_at=now,
+                version=new_version,
+                status=FileStatus.ACTIVE,
+                checksum=checksum,
+                tags=tags or {},
+                parent_version=parent_version,
+            )
+
+            # 6. Update metadata
+            metadata_dict[filename] = file_metadata.to_dict()
+            self._save_metadata(metadata_dict)
+
+            logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
+            return file_metadata
+
+        except Exception as e:
+            logger.exception("Failed to save file with lifecycle")
+            raise
+
+    def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
+        """Get file metadata.
+
+        Args:
+            filename: File name
+
+        Returns:
+            File metadata, or None if the file does not exist
+        """
+        try:
+            metadata_dict = self._load_metadata()
+            if filename in metadata_dict:
+                return FileMetadata.from_dict(metadata_dict[filename])
+            return None
+        except Exception as e:
+            logger.exception("Failed to get file metadata for %s", filename)
+            return None
+
+    def list_file_versions(self, filename: str) -> list[FileMetadata]:
+        """List all versions of a file.
+
+        Args:
+            filename: File name
+
+        Returns:
+            List of file versions, sorted by version number
+        """
+        try:
+            versions = []
+
+            # Get the current version
+            current_metadata = self.get_file_metadata(filename)
+            if current_metadata:
+                versions.append(current_metadata)
+
+            # Get historical versions
+            version_pattern = f"{self._version_prefix}{filename}.v*"
+            try:
+                version_files = self._storage.scan(self._dataset_id or "", files=True)
+                for file_path in version_files:
+                    if file_path.startswith(f"{self._version_prefix}{filename}.v"):
+                        # Parse the version number
+                        version_str = file_path.split(".v")[-1].split(".")[0]
+                        try:
+                            version_num = int(version_str)
+                            # Simplified handling here; metadata should really be read from the version file
+                            # For now only basic metadata information is created
+                        except ValueError:
+                            continue
+            except:
+                # If version files cannot be scanned, return only the current version
+                pass
+
+            return sorted(versions, key=lambda x: x.version or 0, reverse=True)
+
+        except Exception as e:
+            logger.exception("Failed to list file versions for %s", filename)
+            return []
+
+    def restore_version(self, filename: str, version: int) -> bool:
+        """Restore a file to a specified version.
+
+        Args:
+            filename: File name
+            version: Version number to restore
+
+        Returns:
+            Whether the restore succeeded
+        """
+        try:
+            version_filename = f"{self._version_prefix}{filename}.v{version}"
+
+            # Check whether the version file exists
+            if not self._storage.exists(version_filename):
+                logger.warning("Version %s of %s not found", version, filename)
+                return False
+
+            # Read the version file content
+            version_data = self._storage.load_once(version_filename)
+
+            # Save the current version as a backup
+            current_metadata = self.get_file_metadata(filename)
+            if current_metadata:
+                self._create_version_backup(filename, current_metadata.to_dict())
+
+            # Restore the file
+            self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
+            return True
+
+        except Exception as e:
+            logger.exception("Failed to restore %s to version %s", filename, version)
+            return False
+
+    def archive_file(self, filename: str) -> bool:
+        """Archive a file.
+
+        Args:
+            filename: File name
+
+        Returns:
+            Whether archiving succeeded
+        """
+        # Permission check
+        if not self._check_permission(filename, "archive"):
+            logger.warning("Permission denied for archive operation on file: %s", filename)
+            return False
+
+        try:
+            # Update the file status to archived
+            metadata_dict = self._load_metadata()
+            if filename not in metadata_dict:
+                logger.warning("File %s not found in metadata", filename)
+                return False
+
+            metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
+            metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
+
+            self._save_metadata(metadata_dict)
+
+            logger.info("File %s archived successfully", filename)
+            return True
+
+        except Exception as e:
+            logger.exception("Failed to archive file %s", filename)
+            return False
+
+    def soft_delete_file(self, filename: str) -> bool:
+        """Soft delete a file (move it to the deleted directory).
+
+        Args:
+            filename: File name
+
+        Returns:
+            Whether deletion succeeded
+        """
+        # Permission check
+        if not self._check_permission(filename, "delete"):
+            logger.warning("Permission denied for soft delete operation on file: %s", filename)
+            return False
+
+        try:
+            # Check whether the file exists
+            if not self._storage.exists(filename):
+                logger.warning("File %s not found", filename)
+                return False
+
+            # Read the file content
+            file_data = self._storage.load_once(filename)
+
+            # Move to the deleted directory
+            deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+            self._storage.save(deleted_filename, file_data)
+
+            # Delete the original file
+            self._storage.delete(filename)
+
+            # Update metadata
+            metadata_dict = self._load_metadata()
+            if filename in metadata_dict:
+                metadata_dict[filename]["status"] = FileStatus.DELETED.value
+                metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
+                self._save_metadata(metadata_dict)
+
+            logger.info("File %s soft deleted successfully", filename)
+            return True
+
+        except Exception as e:
+            logger.exception("Failed to soft delete file %s", filename)
+            return False
+
+    def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
+        """Clean up old version files.
+
+        Args:
+            max_versions: Maximum number of versions to keep
+            max_age_days: Maximum retention age of version files, in days
+
+        Returns:
+            Number of files cleaned up
+        """
+        try:
+            cleaned_count = 0
+            cutoff_date = datetime.now() - timedelta(days=max_age_days)
+
+            # Get all version files
+            try:
+                all_files = self._storage.scan(self._dataset_id or "", files=True)
+                version_files = [f for f in all_files if f.startswith(self._version_prefix)]
+
+                # Group by file
+                file_versions: dict[str, list[tuple[int, str]]] = {}
+                for version_file in version_files:
+                    # Parse the file name and version
+                    parts = version_file[len(self._version_prefix) :].split(".v")
+                    if len(parts) >= 2:
+                        base_filename = parts[0]
+                        version_part = parts[1].split(".")[0]
+                        try:
+                            version_num = int(version_part)
+                            if base_filename not in file_versions:
+                                file_versions[base_filename] = []
+                            file_versions[base_filename].append((version_num, version_file))
+                        except ValueError:
+                            continue
+
+                # Clean up old versions of each file
+                for base_filename,
versions in file_versions.items(): + # 按版本号排序 + versions.sort(key=lambda x: x[0], reverse=True) + + # 保留最新的max_versions个版本,删除其余的 + if len(versions) > max_versions: + to_delete = versions[max_versions:] + for version_num, version_file in to_delete: + self._storage.delete(version_file) + cleaned_count += 1 + logger.debug("Cleaned old version: %s", version_file) + + logger.info("Cleaned %d old version files", cleaned_count) + + except Exception as e: + logger.warning("Could not scan for version files: %s", e) + + return cleaned_count + + except Exception as e: + logger.exception("Failed to cleanup old versions") + return 0 + + def get_storage_statistics(self) -> dict[str, Any]: + """获取存储统计信息 + + Returns: + 存储统计字典 + """ + try: + metadata_dict = self._load_metadata() + + stats: dict[str, Any] = { + "total_files": len(metadata_dict), + "active_files": 0, + "archived_files": 0, + "deleted_files": 0, + "total_size": 0, + "versions_count": 0, + "oldest_file": None, + "newest_file": None, + } + + oldest_date = None + newest_date = None + + for filename, metadata in metadata_dict.items(): + file_meta = FileMetadata.from_dict(metadata) + + # 统计文件状态 + if file_meta.status == FileStatus.ACTIVE: + stats["active_files"] = (stats["active_files"] or 0) + 1 + elif file_meta.status == FileStatus.ARCHIVED: + stats["archived_files"] = (stats["archived_files"] or 0) + 1 + elif file_meta.status == FileStatus.DELETED: + stats["deleted_files"] = (stats["deleted_files"] or 0) + 1 + + # 统计大小 + stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0) + + # 统计版本 + stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0) + + # 找出最新和最旧的文件 + if oldest_date is None or file_meta.created_at < oldest_date: + oldest_date = file_meta.created_at + stats["oldest_file"] = filename + + if newest_date is None or file_meta.modified_at > newest_date: + newest_date = file_meta.modified_at + stats["newest_file"] = filename + + return stats + + except Exception as e: + logger.exception("Failed to get storage statistics") + return {} + + def _create_version_backup(self, filename: str, metadata: dict): + """创建版本备份""" + try: + # 读取当前文件内容 + current_data = self._storage.load_once(filename) + + # 保存为版本文件 + version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}" + self._storage.save(version_filename, current_data) + + logger.debug("Created version backup: %s", version_filename) + + except Exception as e: + logger.warning("Failed to create version backup for %s: %s", filename, e) + + def _load_metadata(self) -> dict[str, Any]: + """加载元数据文件""" + try: + if self._storage.exists(self._metadata_file): + metadata_content = self._storage.load_once(self._metadata_file) + result = json.loads(metadata_content.decode("utf-8")) + return dict(result) if result else {} + else: + return {} + except Exception as e: + logger.warning("Failed to load metadata: %s", e) + return {} + + def _save_metadata(self, metadata_dict: dict): + """保存元数据文件""" + try: + metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) + self._storage.save(self._metadata_file, metadata_content.encode("utf-8")) + logger.debug("Metadata saved successfully") + except Exception as e: + logger.exception("Failed to save metadata") + raise + + def _calculate_checksum(self, data: bytes) -> str: + """计算文件校验和""" + import hashlib + + return hashlib.md5(data).hexdigest() + + def _check_permission(self, filename: str, operation: str) -> bool: + """检查文件操作权限 + + Args: + filename: 文件名 + operation: 操作类型 + + Returns: + True if 
permission granted, False otherwise + """ + # 如果没有权限管理器,默认允许 + if not self._permission_manager: + return True + + try: + # 根据操作类型映射到权限 + operation_mapping = { + "save": "save", + "load": "load_once", + "delete": "delete", + "archive": "delete", # 归档需要删除权限 + "restore": "save", # 恢复需要写权限 + "cleanup": "delete", # 清理需要删除权限 + "read": "load_once", + "write": "save", + } + + mapped_operation = operation_mapping.get(operation, operation) + + # 检查权限 + result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id) + return bool(result) + + except Exception as e: + logger.exception("Permission check failed for %s operation %s", filename, operation) + # 安全默认:权限检查失败时拒绝访问 + return False diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py new file mode 100644 index 0000000000..4801df5102 --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -0,0 +1,646 @@ +"""ClickZetta Volume权限管理机制 + +该模块提供Volume权限检查、验证和管理功能。 +根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。 +""" + +import logging +from enum import Enum +from typing import Optional + +logger = logging.getLogger(__name__) + + +class VolumePermission(Enum): + """Volume权限类型枚举""" + + READ = "SELECT" # 对应ClickZetta的SELECT权限 + WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限 + LIST = "SELECT" # 列出文件需要SELECT权限 + DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限 + USAGE = "USAGE" # External Volume需要的基本权限 + + +class VolumePermissionManager: + """Volume权限管理器""" + + def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): + """初始化权限管理器 + + Args: + connection_or_config: ClickZetta连接对象或配置字典 + volume_type: Volume类型 (user|table|external) + volume_name: Volume名称 (用于external volume) + """ + # 支持两种初始化方式:连接对象或配置字典 + if isinstance(connection_or_config, dict): + # 从配置字典创建连接 + import clickzetta # type: ignore[import-untyped] + + config = connection_or_config + self._connection = clickzetta.connect( + username=config.get("username"), + password=config.get("password"), + instance=config.get("instance"), + service=config.get("service"), + workspace=config.get("workspace"), + vcluster=config.get("vcluster"), + schema=config.get("schema") or config.get("database"), + ) + self._volume_type = config.get("volume_type", volume_type) + self._volume_name = config.get("volume_name", volume_name) + else: + # 直接使用连接对象 + self._connection = connection_or_config + self._volume_type = volume_type + self._volume_name = volume_name + + if not self._connection: + raise ValueError("Valid connection or config is required") + if not self._volume_type: + raise ValueError("volume_type is required") + + self._permission_cache: dict[str, set[str]] = {} + self._current_username = None # 将从连接中获取当前用户名 + + def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: + """检查用户是否有执行特定操作的权限 + + Args: + operation: 要执行的操作类型 + dataset_id: 数据集ID (用于table volume) + + Returns: + True if user has permission, False otherwise + """ + try: + if self._volume_type == "user": + return self._check_user_volume_permission(operation) + elif self._volume_type == "table": + return self._check_table_volume_permission(operation, dataset_id) + elif self._volume_type == "external": + return self._check_external_volume_permission(operation) + else: + logger.warning("Unknown volume type: %s", self._volume_type) + return False + + except Exception as e: + logger.exception("Permission check failed") + return False + + 
def _check_user_volume_permission(self, operation: VolumePermission) -> bool: + """检查User Volume权限 + + User Volume权限规则: + - 用户对自己的User Volume有全部权限 + - 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限 + - 更注重连接身份验证,而不是复杂的权限检查 + """ + try: + # 获取当前用户名 + current_user = self._get_current_username() + + # 检查基本连接状态 + with self._connection.cursor() as cursor: + # 简单的连接测试,如果能执行查询说明用户有基本权限 + cursor.execute("SELECT 1") + result = cursor.fetchone() + + if result: + logger.debug( + "User Volume permission check for %s, operation %s: granted (basic connection verified)", + current_user, + operation.name, + ) + return True + else: + logger.warning( + "User Volume permission check failed: cannot verify basic connection for %s", current_user + ) + return False + + except Exception as e: + logger.exception("User Volume permission check failed") + # 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示 + logger.info("User Volume permission check failed, but permission checking is disabled in this version") + return False + + def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: + """检查Table Volume权限 + + Table Volume权限规则: + - Table Volume权限继承对应表的权限 + - SELECT权限 -> 可以READ/LIST文件 + - INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件 + """ + if not dataset_id: + logger.warning("dataset_id is required for table volume permission check") + return False + + table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id + + try: + # 检查表权限 + permissions = self._get_table_permissions(table_name) + required_permissions = set(operation.value.split(",")) + + # 检查是否有所需的所有权限 + has_permission = required_permissions.issubset(permissions) + + logger.debug( + "Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s", + table_name, + operation.name, + required_permissions, + permissions, + has_permission, + ) + + return has_permission + + except Exception as e: + logger.exception("Table volume permission check failed for %s", table_name) + return False + + def _check_external_volume_permission(self, operation: VolumePermission) -> bool: + """检查External Volume权限 + + External Volume权限规则: + - 尝试获取对External Volume的权限 + - 如果权限检查失败,进行备选验证 + - 对于开发环境,提供更宽松的权限检查 + """ + if not self._volume_name: + logger.warning("volume_name is required for external volume permission check") + return False + + try: + # 检查External Volume权限 + permissions = self._get_external_volume_permissions(self._volume_name) + + # External Volume权限映射:根据操作类型确定所需权限 + required_permissions = set() + + if operation in [VolumePermission.READ, VolumePermission.LIST]: + required_permissions.add("read") + elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: + required_permissions.add("write") + + # 检查是否有所需的所有权限 + has_permission = required_permissions.issubset(permissions) + + logger.debug( + "External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s", + self._volume_name, + operation.name, + required_permissions, + permissions, + has_permission, + ) + + # 如果权限检查失败,尝试备选验证 + if not has_permission: + logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name) + + # 备选验证:尝试列出Volume来验证基本访问权限 + try: + with self._connection.cursor() as cursor: + cursor.execute("SHOW VOLUMES") + volumes = cursor.fetchall() + for volume in volumes: + if len(volume) > 0 and volume[0] == self._volume_name: + logger.info("Fallback verification successful for %s", self._volume_name) + return True + except Exception as fallback_e: + 
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e) + + return has_permission + + except Exception as e: + logger.exception("External volume permission check failed for %s", self._volume_name) + logger.info("External Volume permission check failed, but permission checking is disabled in this version") + return False + + def _get_table_permissions(self, table_name: str) -> set[str]: + """获取用户对指定表的权限 + + Args: + table_name: 表名 + + Returns: + 用户对该表的权限集合 + """ + cache_key = f"table:{table_name}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查当前用户权限 + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() + + # 解析权限结果,查找对该表的权限 + for grant in grants: + if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) + privilege = grant[0].upper() + object_type = grant[1].upper() if len(grant) > 1 else "" + object_name = grant[2] if len(grant) > 2 else "" + + # 检查是否是对该表的权限 + if ( + object_type == "TABLE" + and object_name == table_name + or object_type == "SCHEMA" + and object_name in table_name + ): + if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: + if privilege == "ALL": + permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) + else: + permissions.add(privilege) + + # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限 + if not permissions: + try: + cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1") + permissions.add("SELECT") + except Exception: + logger.debug("Cannot query table %s, no SELECT permission", table_name) + + except Exception as e: + logger.warning("Could not check table permissions for %s: %s", table_name, e) + # 安全默认:权限检查失败时拒绝访问 + pass + + # 缓存权限信息 + self._permission_cache[cache_key] = permissions + return permissions + + def _get_current_username(self) -> str: + """获取当前用户名""" + if self._current_username: + return self._current_username + + try: + with self._connection.cursor() as cursor: + cursor.execute("SELECT CURRENT_USER()") + result = cursor.fetchone() + if result: + self._current_username = result[0] + return str(self._current_username) + except Exception as e: + logger.exception("Failed to get current username") + + return "unknown" + + def _get_user_permissions(self, username: str) -> set[str]: + """获取用户的基本权限集合""" + cache_key = f"user_permissions:{username}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查当前用户权限 + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() + + # 解析权限结果,查找用户的基本权限 + for grant in grants: + if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) 
+                    privilege = grant[0].upper()
+                    object_type = grant[1].upper() if len(grant) > 1 else ""
+
+                    # Collect all relevant privileges
+                    if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
+                        if privilege == "ALL":
+                            permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
+                        else:
+                            permissions.add(privilege)
+
+        except Exception as e:
+            logger.warning("Could not check user permissions for %s: %s", username, e)
+            # Fail closed: deny access when the permission check fails
+            pass
+
+        # Cache the permission set
+        self._permission_cache[cache_key] = permissions
+        return permissions
+
+    def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
+        """Get the current user's permissions on the specified external volume.
+
+        Args:
+            volume_name: External volume name
+
+        Returns:
+            The set of permissions the user holds on that volume
+        """
+        cache_key = f"external_volume:{volume_name}"
+
+        if cache_key in self._permission_cache:
+            return self._permission_cache[cache_key]
+
+        permissions = set()
+
+        try:
+            with self._connection.cursor() as cursor:
+                # Check volume permissions using the correct ClickZetta syntax
+                logger.info("Checking permissions for volume: %s", volume_name)
+                cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
+                grants = cursor.fetchall()
+
+                logger.info("Raw grants result for %s: %s", volume_name, grants)
+
+                # Parse the grant results
+                # Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
+                #          grantee_name, grantor_name, grant_option, granted_time)
+                for grant in grants:
+                    logger.info("Processing grant: %s", grant)
+                    if len(grant) >= 5:
+                        granted_type = grant[0]
+                        privilege = grant[1].upper()
+                        granted_on = grant[3]
+                        object_name = grant[4]
+
+                        logger.info(
+                            "Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
+                            granted_type,
+                            privilege,
+                            granted_on,
+                            object_name,
+                        )
+
+                        # Check whether the grant targets this volume directly or via an object hierarchy
+                        if (
+                            granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
+                        ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
+                            logger.info("Matching grant found for %s", volume_name)
+
+                            if "READ" in privilege:
+                                permissions.add("read")
+                                logger.info("Added READ permission for %s", volume_name)
+                            if "WRITE" in privilege:
+                                permissions.add("write")
+                                logger.info("Added WRITE permission for %s", volume_name)
+                            if "ALTER" in privilege:
+                                permissions.add("alter")
+                                logger.info("Added ALTER permission for %s", volume_name)
+                            if privilege == "ALL":
+                                permissions.update(["read", "write", "alter"])
+                                logger.info("Added ALL permissions for %s", volume_name)
+
+                logger.info("Final permissions for %s: %s", volume_name, permissions)
+
+                # If no explicit grants were found, check the volume list to verify basic access
+                if not permissions:
+                    try:
+                        cursor.execute("SHOW VOLUMES")
+                        volumes = cursor.fetchall()
+                        for volume in volumes:
+                            if len(volume) > 0 and volume[0] == volume_name:
+                                permissions.add("read")  # at least read permission
+                                logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
+                                break
+                    except Exception:
+                        logger.debug("Cannot access volume %s, no basic permission", volume_name)
+
+        except Exception as e:
+            logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
+            # When the permission check fails, fall back to a basic volume access check
+            try:
+                with self._connection.cursor() as cursor:
+                    cursor.execute("SHOW VOLUMES")
+                    volumes = cursor.fetchall()
+                    for volume in volumes:
+                        if len(volume) > 0 and volume[0] == volume_name:
+                            logger.info("Basic volume access verified for %s", volume_name)
+                            permissions.add("read")
+                            permissions.add("write")  # assume write permission
+                            break
+            except Exception as basic_e:
+                logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
+                # Last resort: assume basic permissions
+                permissions.add("read")
+
+        # Cache the permission set
+        self._permission_cache[cache_key] = permissions
+        return permissions
+
+    def clear_permission_cache(self):
+        """Clear the permission cache."""
+        self._permission_cache.clear()
+        logger.debug("Permission cache cleared")
+
+    def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
+        """Get a summary of the current permissions.
+
+        Args:
+            dataset_id: Dataset ID (used for table volumes)
+
+        Returns:
+            A dict mapping each operation name to whether it is permitted
+        """
+        summary = {}
+
+        for operation in VolumePermission:
+            summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
+
+        return summary
+
+    def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
+        """Check permission inheritance for a file path.
+
+        Args:
+            file_path: File path
+            operation: Operation to perform
+
+        Returns:
+            True if user has permission, False otherwise
+        """
+        try:
+            # Parse the file path
+            path_parts = file_path.strip("/").split("/")
+
+            if not path_parts:
+                logger.warning("Invalid file path for permission inheritance check")
+                return False
+
+            # For table volumes, the first path segment is the dataset_id
+            if self._volume_type == "table":
+                if len(path_parts) < 1:
+                    return False
+
+                dataset_id = path_parts[0]
+
+                # Check the permission on the dataset
+                has_dataset_permission = self.check_permission(operation, dataset_id)
+
+                if not has_dataset_permission:
+                    logger.debug("Permission denied for dataset %s", dataset_id)
+                    return False
+
+                # Check for path traversal attacks
+                if self._contains_path_traversal(file_path):
+                    logger.warning("Path traversal attack detected: %s", file_path)
+                    return False
+
+                # Check for access to sensitive directories
+                if self._is_sensitive_path(file_path):
+                    logger.warning("Access to sensitive path denied: %s", file_path)
+                    return False
+
+                logger.debug("Permission inherited for path %s", file_path)
+                return True
+
+            elif self._volume_type == "user":
+                # Permission inheritance for user volumes
+                current_user = self._get_current_username()
+
+                # Check whether another user's directory is being accessed
+                if len(path_parts) > 1 and path_parts[0] != current_user:
+                    logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
+                    return False
+
+                # Check the basic permission
+                return self.check_permission(operation)
+
+            elif self._volume_type == "external":
+                # Permission inheritance for external volumes:
+                # check the permission on the external volume itself
+                return self.check_permission(operation)
+
+            else:
+                logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
+                return False
+
+        except Exception:
+            logger.exception("Permission inheritance check failed")
+            return False
+
+    def _contains_path_traversal(self, file_path: str) -> bool:
+        """Check whether the path contains path traversal patterns."""
+        # Common path traversal patterns
+        traversal_patterns = [
+            "../",
+            "..\\",
+            "..%2f",
+            "..%2F",
+            "..%5c",
+            "..%5C",
+            "%2e%2e%2f",
+            "%2e%2e%5c",
+            "....//",
+            "....\\\\",
+        ]
+
+        file_path_lower = file_path.lower()
+
+        for pattern in traversal_patterns:
+            if pattern in file_path_lower:
+                return True
+
+        # Reject absolute paths
+        if file_path.startswith("/") or file_path.startswith("\\"):
+            return True
+
+        # Reject Windows drive paths
+        if len(file_path) >= 2 and file_path[1] == ":":
+            return True
+
+        return False
+
+    def _is_sensitive_path(self, file_path: str) -> bool:
+        """Check whether the path points to a sensitive location."""
+        sensitive_patterns = [
+            "passwd",
+            "shadow",
+            "hosts",
+            "config",
+            "secrets",
+            "private",
+            "key",
+            "certificate",
+            "cert",
+            "ssl",
+            "database",
+            "backup",
+            "dump",
+            "log",
+            "tmp",
+        ]
+
+        file_path_lower = file_path.lower()
+
+        return any(pattern in file_path_lower for pattern in sensitive_patterns)
+
+    def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
+        """Validate whether the requested operation is permitted.
+
+        Args:
+            operation: Operation name (save|load|exists|delete|scan)
+            dataset_id: Dataset ID
+
+        Returns:
+            True if operation is allowed, False otherwise
+        """
+        operation_mapping = {
+            "save": VolumePermission.WRITE,
+            "load": VolumePermission.READ,
+            "load_once": VolumePermission.READ,
+            "load_stream": VolumePermission.READ,
+            "download": VolumePermission.READ,
+            "exists": VolumePermission.READ,
+            "delete": VolumePermission.DELETE,
+            "scan": VolumePermission.LIST,
+        }
+
+        if operation not in operation_mapping:
+            logger.warning("Unknown operation: %s", operation)
+            return False
+
+        volume_permission = operation_mapping[operation]
+        return self.check_permission(volume_permission, dataset_id)
+
+
+class VolumePermissionError(Exception):
+    """Exception raised when a volume permission check fails."""
+
+    def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
+        self.operation = operation
+        self.volume_type = volume_type
+        self.dataset_id = dataset_id
+        super().__init__(message)
+
+
+def check_volume_permission(
+    permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
+) -> None:
+    """Permission check helper function.
+
+    Args:
+        permission_manager: The permission manager
+        operation: Operation name
+        dataset_id: Dataset ID
+
+    Raises:
+        VolumePermissionError: If the user lacks the required permission
+    """
+    if not permission_manager.validate_operation(operation, dataset_id):
+        error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
+        if dataset_id:
+            error_message += f" (dataset: {dataset_id})"
+
+        raise VolumePermissionError(
+            error_message,
+            operation=operation,
+            volume_type=permission_manager._volume_type or "unknown",
+            dataset_id=dataset_id,
+        )
diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py
index 0a891e36cf..bc2d632159 100644
--- a/api/extensions/storage/storage_type.py
+++ b/api/extensions/storage/storage_type.py
@@ -5,6 +5,7 @@ class StorageType(StrEnum):
     ALIYUN_OSS = "aliyun-oss"
     AZURE_BLOB = "azure-blob"
     BAIDU_OBS = "baidu-obs"
+    CLICKZETTA_VOLUME = "clickzetta-volume"
     GOOGLE_STORAGE = "google-storage"
     HUAWEI_OBS = "huawei-obs"
     LOCAL = "local"
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 512a9cb608..b2bcee5dcd 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -1,4 +1,6 @@
 import mimetypes
+import os
+import urllib.parse
 import uuid
 from collections.abc import Callable, Mapping, Sequence
 from typing import Any, cast
@@ -240,16 +242,21 @@ def _build_from_remote_url(
 
 def _get_remote_file_info(url: str):
     file_size = -1
-    filename = url.split("/")[-1].split("?")[0] or "unknown_file"
-    mime_type = mimetypes.guess_type(filename)[0] or ""
+    parsed_url = urllib.parse.urlparse(url)
+    url_path = parsed_url.path
+    filename = os.path.basename(url_path)
+
+    # Initialize mime_type from filename as fallback
+    mime_type, _ = mimetypes.guess_type(filename)
 
     resp = ssrf_proxy.head(url, follow_redirects=True)
     resp = cast(httpx.Response, resp)
     if resp.status_code == httpx.codes.OK:
         if content_disposition := resp.headers.get("Content-Disposition"):
             filename = str(content_disposition.split("filename=")[-1].strip('"'))
+            # Re-guess mime_type from updated filename
+            mime_type, _ = mimetypes.guess_type(filename)
         file_size = int(resp.headers.get("Content-Length", file_size))
-        mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
 
     return mime_type, filename, file_size
 
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index b6d85e0e24..1a5fcabf97 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -59,6 +59,8 @@ model_config_fields = {
    "updated_at": TimestampField,
 }
 
+tag_fields = 
{"id": fields.String, "name": fields.String, "type": fields.String} + app_detail_fields = { "id": fields.String, "name": fields.String, @@ -77,6 +79,7 @@ app_detail_fields = { "updated_by": fields.String, "updated_at": TimestampField, "access_mode": fields.String, + "tags": fields.List(fields.Nested(tag_fields)), } prompt_config_fields = { @@ -92,8 +95,6 @@ model_config_partial_fields = { "updated_at": TimestampField, } -tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} - app_partial_fields = { "id": fields.String, "name": fields.String, @@ -185,7 +186,6 @@ app_detail_fields_with_site = { "enable_api": fields.Boolean, "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), "workflow": fields.Nested(workflow_partial_fields, allow_null=True), - "site": fields.Nested(site_fields), "api_base_url": fields.String, "use_icon_as_answer_icon": fields.Boolean, "max_active_requests": fields.Integer, @@ -195,6 +195,8 @@ app_detail_fields_with_site = { "updated_at": TimestampField, "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)), "access_mode": fields.String, + "tags": fields.List(fields.Nested(tag_fields)), + "site": fields.Nested(site_fields), } diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 598e5bc9e3..c72032701f 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,5 +1,4 @@ import hashlib -import os from typing import Union from Crypto.Cipher import AES @@ -18,7 +17,7 @@ def generate_key_pair(tenant_id: str) -> str: pem_private = private_key.export_key() pem_public = public_key.export_key() - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" storage.save(filepath, pem_private) @@ -48,7 +47,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}" private_key = redis_client.get(cache_key) diff --git a/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py b/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py new file mode 100644 index 0000000000..1664fb99c4 --- /dev/null +++ b/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py @@ -0,0 +1,25 @@ +"""manual dataset field update + +Revision ID: 532b3f888abf +Revises: 8bcc02c9bd07 +Create Date: 2025-07-24 14:50:48.779833 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '532b3f888abf' +down_revision = '8bcc02c9bd07' +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying") + + +def downgrade(): + op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'") diff --git a/api/models/account.py b/api/models/account.py index d63c5d7fb5..1a0752440d 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,8 +3,9 @@ import json from datetime import datetime from typing import Optional, cast +import sqlalchemy as sa from flask_login import UserMixin # type: ignore -from sqlalchemy import func, select +from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -83,26 +84,24 @@ class AccountStatus(enum.StrEnum): class Account(UserMixin, Base): __tablename__ = "accounts" - __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) - email: Mapped[str] = mapped_column(db.String(255)) - password: Mapped[Optional[str]] = mapped_column(db.String(255)) - password_salt: Mapped[Optional[str]] = mapped_column(db.String(255)) - avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - interface_language: Mapped[Optional[str]] = mapped_column(db.String(255)) - interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - timezone: Mapped[Optional[str]] = mapped_column(db.String(255)) - last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column( - db.DateTime, server_default=func.current_timestamp(), nullable=False - ) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying")) - initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(String(255)) + email: Mapped[str] = mapped_column(String(255)) + password: Mapped[Optional[str]] = mapped_column(String(255)) + password_salt: Mapped[Optional[str]] = mapped_column(String(255)) + avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + interface_language: Mapped[Optional[str]] = mapped_column(String(255)) + interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + timezone: Mapped[Optional[str]] = mapped_column(String(255)) + last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character 
varying")) + initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): @@ -197,16 +196,16 @@ class TenantStatus(enum.StrEnum): class Tenant(Base): __tablename__ = "tenants" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) - encrypt_public_key = db.Column(db.Text) - plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) - custom_config: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(String(255)) + encrypt_public_key = db.Column(sa.Text) + plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) + custom_config: Mapped[Optional[str]] = mapped_column(sa.Text) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -227,56 +226,56 @@ class Tenant(Base): class TenantAccountJoin(Base): __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), - db.Index("tenant_account_join_account_id_idx", "account_id"), - db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), - db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), + sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + sa.Index("tenant_account_join_account_id_idx", "account_id"), + sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) - current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - role: Mapped[str] = mapped_column(db.String(16), server_default="normal") + current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + role: Mapped[str] = mapped_column(String(16), server_default="normal") invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, 
server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class AccountIntegrate(Base): __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), - db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), - db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), + sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) account_id: Mapped[str] = mapped_column(StringUUID) - provider: Mapped[str] = mapped_column(db.String(16)) - open_id: Mapped[str] = mapped_column(db.String(255)) - encrypted_token: Mapped[str] = mapped_column(db.String(255)) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + provider: Mapped[str] = mapped_column(String(16)) + open_id: Mapped[str] = mapped_column(String(255)) + encrypted_token: Mapped[str] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class InvitationCode(Base): __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), - db.Index("invitation_codes_batch_idx", "batch"), - db.Index("invitation_codes_code_idx", "code", "status"), + sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + sa.Index("invitation_codes_batch_idx", "batch"), + sa.Index("invitation_codes_code_idx", "code", "status"), ) - id: Mapped[int] = mapped_column(db.Integer) - batch: Mapped[str] = mapped_column(db.String(255)) - code: Mapped[str] = mapped_column(db.String(32)) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying")) - used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + id: Mapped[int] = mapped_column(sa.Integer) + batch: Mapped[str] = mapped_column(String(255)) + code: Mapped[str] = mapped_column(String(32)) + status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) + used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) + deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -292,16 +291,14 @@ class TenantPluginPermission(Base): __tablename__ = "account_plugin_permissions" __table_args__ = ( - db.PrimaryKeyConstraint("id", 
name="account_plugin_permission_pkey"), - db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), + sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), + sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column( - db.String(16), nullable=False, server_default="everyone" - ) - debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") + debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") class TenantPluginAutoUpgradeStrategy(Base): @@ -317,20 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base): __tablename__ = "tenant_plugin_auto_upgrade_strategies" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), - db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), + sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), + sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only") - upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") + upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day + upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") + exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 3cef5a0fb2..60167d9069 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,10 +1,11 @@ import enum +from datetime import datetime -from 
sqlalchemy import func -from sqlalchemy.orm import mapped_column +import sqlalchemy as sa +from sqlalchemy import DateTime, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column from .base import Base -from .engine import db from .types import StringUUID @@ -18,13 +19,13 @@ class APIBasedExtensionPoint(enum.Enum): class APIBasedExtension(Base): __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), - db.Index("api_based_extension_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + sa.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - api_endpoint = mapped_column(db.String(255), nullable=False) - api_key = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + name: Mapped[str] = mapped_column(String(255), nullable=False) + api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False) + api_key = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 01372f8bf6..3b1d289bc4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,8 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast -from sqlalchemy import func, select +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -38,32 +39,32 @@ class DatasetPermissionEnum(enum.StrEnum): class Dataset(Base): __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_pkey"), - db.Index("dataset_tenant_idx", "tenant_id"), - db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="dataset_pkey"), + sa.Index("dataset_tenant_idx", "tenant_id"), + sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) - description = mapped_column(db.Text, nullable=True) - provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying")) - permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying")) - data_source_type = mapped_column(db.String(255)) - indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255)) - index_struct = mapped_column(db.Text, nullable=True) + name: Mapped[str] = mapped_column(String(255)) + description = mapped_column(sa.Text, nullable=True) + provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(String(255), 
server_default=sa.text("'only_me'::character varying")) + data_source_type = mapped_column(String(255)) + indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) + index_struct = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column - embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(String(255), nullable=True) # TODO: mapped_column + embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def dataset_keyword_table(self): @@ -262,16 +263,16 @@ class Dataset(Base): class DatasetProcessRule(Base): __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), - db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - rules = mapped_column(db.Text, nullable=True) + mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying")) + rules = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -302,72 +303,70 @@ class DatasetProcessRule(Base): class Document(Base): __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="document_pkey"), - db.Index("document_dataset_id_idx", "dataset_id"), - db.Index("document_is_paused_idx", "is_paused"), - db.Index("document_tenant_idx", "tenant_id"), - db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="document_pkey"), + sa.Index("document_dataset_id_idx", "dataset_id"), + sa.Index("document_is_paused_idx", "is_paused"), + sa.Index("document_tenant_idx", "tenant_id"), + sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), ) # 
initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) - data_source_type = mapped_column(db.String(255), nullable=False) - data_source_info = mapped_column(db.Text, nullable=True) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) + data_source_info = mapped_column(sa.Text, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) - batch = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_from = mapped_column(db.String(255), nullable=False) + batch: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at = mapped_column(db.DateTime, nullable=True) + processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # parsing - file_id = mapped_column(db.Text, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - parsing_completed_at = mapped_column(db.DateTime, nullable=True) + file_id = mapped_column(sa.Text, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable + parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # cleaning - cleaning_completed_at = mapped_column(db.DateTime, nullable=True) + cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # split - splitting_completed_at = mapped_column(db.DateTime, nullable=True) + splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # indexing - tokens = mapped_column(db.Integer, nullable=True) - indexing_latency = mapped_column(db.Float, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # pause - is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) paused_by = mapped_column(StringUUID, nullable=True) - paused_at = mapped_column(db.DateTime, nullable=True) + paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # error - error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + error = mapped_column(sa.Text, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column( - 
db.String(255), nullable=False, server_default=db.text("'waiting'::character varying") - ) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying")) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - archived_reason = mapped_column(db.String(255), nullable=True) + archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) - archived_at = mapped_column(db.DateTime, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - doc_type = mapped_column(db.String(40), nullable=True) + archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) - doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) - doc_language = mapped_column(db.String(255), nullable=True) + doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying")) + doc_language = mapped_column(String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -524,7 +523,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.upload_date, "type": "time", - "value": self.created_at.timestamp(), + "value": str(self.created_at.timestamp()), } ) built_in_fields.append( @@ -532,7 +531,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.last_update_date, "type": "time", - "value": self.updated_at.timestamp(), + "value": str(self.updated_at.timestamp()), } ) built_in_fields.append( @@ -645,45 +644,45 @@ class Document(Base): class DocumentSegment(Base): __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="document_segment_pkey"), - db.Index("document_segment_dataset_id_idx", "dataset_id"), - db.Index("document_segment_document_id_idx", "document_id"), - db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), - db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), - db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), - db.Index("document_segment_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="document_segment_pkey"), + sa.Index("document_segment_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_document_id_idx", "document_id"), + sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), + sa.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) 
+ id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] - content = mapped_column(db.Text, nullable=False) - answer = mapped_column(db.Text, nullable=True) + content = mapped_column(sa.Text, nullable=False) + answer = mapped_column(sa.Text, nullable=True) word_count: Mapped[int] tokens: Mapped[int] # indexing fields - keywords = mapped_column(db.JSON, nullable=True) - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) + keywords = mapped_column(sa.JSON, nullable=True) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) # basic fields - hit_count = mapped_column(db.Integer, nullable=False, default=0) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + error = mapped_column(sa.Text, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) @property def dataset(self): @@ -796,32 +795,36 @@ class DocumentSegment(Base): class ChildChunk(Base): __tablename__ = "child_chunks" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), - db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), - db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), - db.Index("child_chunks_segment_idx", "segment_id"), + sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), + sa.Index("child_chunks_segment_idx", "segment_id"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, 
server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) segment_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) - content = mapped_column(db.Text, nullable=False) - word_count = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + content = mapped_column(sa.Text, nullable=False) + word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) # indexing fields - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) - type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) + type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) - error = mapped_column(db.Text, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + error = mapped_column(sa.Text, nullable=True) @property def dataset(self): @@ -839,14 +842,14 @@ class ChildChunk(Base): class AppDatasetJoin(Base): __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), - db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), + sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def app(self): @@ -856,32 +859,32 @@ class AppDatasetJoin(Base): class DatasetQuery(Base): __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), - db.Index("dataset_query_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + sa.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = 
mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - content = mapped_column(db.Text, nullable=False) - source = mapped_column(db.String(255), nullable=False) + content = mapped_column(sa.Text, nullable=False) + source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id = mapped_column(StringUUID, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) class DatasetKeywordTable(Base): __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), - db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False, unique=True) - keyword_table = mapped_column(db.Text, nullable=False) + keyword_table = mapped_column(sa.Text, nullable=False) data_source_type = mapped_column( - db.String(255), nullable=False, server_default=db.text("'database'::character varying") + String(255), nullable=False, server_default=sa.text("'database'::character varying") ) @property @@ -918,19 +921,19 @@ class DatasetKeywordTable(Base): class Embedding(Base): __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="embedding_pkey"), - db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), - db.Index("created_at_idx", "created_at"), + sa.PrimaryKeyConstraint("id", name="embedding_pkey"), + sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + sa.Index("created_at_idx", "created_at"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) model_name = mapped_column( - db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying") ) - hash = mapped_column(db.String(64), nullable=False) - embedding = mapped_column(db.LargeBinary, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + hash = mapped_column(String(64), nullable=False) + embedding = mapped_column(sa.LargeBinary, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying")) def set_embedding(self, embedding_data: 
list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -942,84 +945,84 @@ class Embedding(Base): class DatasetCollectionBinding(Base): __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), - db.Index("provider_model_name_idx", "provider_name", "model_name"), + sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + sa.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - provider_name = mapped_column(db.String(255), nullable=False) - model_name = mapped_column(db.String(255), nullable=False) - type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) - collection_name = mapped_column(db.String(64), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False) + collection_name = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): __tablename__ = "tidb_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), - db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), - db.Index("tidb_auth_bindings_active_idx", "active"), - db.Index("tidb_auth_bindings_created_at_idx", "created_at"), - db.Index("tidb_auth_bindings_status_idx", "status"), + sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), + sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), + sa.Index("tidb_auth_bindings_active_idx", "active"), + sa.Index("tidb_auth_bindings_created_at_idx", "created_at"), + sa.Index("tidb_auth_bindings_status_idx", "status"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - cluster_id = mapped_column(db.String(255), nullable=False) - cluster_name = mapped_column(db.String(255), nullable=False) - active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = mapped_column(db.String(255), nullable=False) - password = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) + cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) + active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying")) + account: Mapped[str] = mapped_column(String(255), nullable=False) + password: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, 
nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): __tablename__ = "whitelists" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="whitelists_pkey"), - db.Index("whitelists_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="whitelists_pkey"), + sa.Index("whitelists_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - category = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), - db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), - db.Index("idx_dataset_permissions_account_id", "account_id"), - db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + sa.Index("idx_dataset_permissions_account_id", "account_id"), + sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True) dataset_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): __tablename__ = "external_knowledge_apis" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), - db.Index("external_knowledge_apis_tenant_idx", "tenant_id"), - db.Index("external_knowledge_apis_name_idx", "name"), + sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), + sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"), + sa.Index("external_knowledge_apis_name_idx", "name"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - settings = mapped_column(db.Text, nullable=True) + settings = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: 
Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1059,71 +1062,79 @@ class ExternalKnowledgeApis(Base): class ExternalKnowledgeBindings(Base): __tablename__ = "external_knowledge_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), - db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), - db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), - db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), - db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), + sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), + sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), + sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), + sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), + sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) external_knowledge_api_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - external_knowledge_id = mapped_column(db.Text, nullable=False) + external_knowledge_id = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): __tablename__ = "dataset_auto_disable_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), - db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), - db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), - db.Index("dataset_auto_disable_log_created_atx", "created_at"), + sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + sa.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + 
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) class RateLimitLog(Base): __tablename__ = "rate_limit_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), - db.Index("rate_limit_log_tenant_idx", "tenant_id"), - db.Index("rate_limit_log_operation_idx", "operation"), + sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), + sa.Index("rate_limit_log_tenant_idx", "tenant_id"), + sa.Index("rate_limit_log_operation_idx", "operation"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - subscription_plan = mapped_column(db.String(255), nullable=False) - operation = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False) + operation: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) class DatasetMetadata(Base): __tablename__ = "dataset_metadatas" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), - db.Index("dataset_metadata_tenant_idx", "tenant_id"), - db.Index("dataset_metadata_dataset_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), + sa.Index("dataset_metadata_tenant_idx", "tenant_id"), + sa.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) created_by = mapped_column(StringUUID, nullable=False) updated_by = mapped_column(StringUUID, nullable=True) @@ -1131,17 +1142,17 @@ class DatasetMetadata(Base): class DatasetMetadataBinding(Base): __tablename__ = "dataset_metadata_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), - db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), - db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), - db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), - db.Index("dataset_metadata_binding_document_idx", "document_id"), + sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), + sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), + 
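
# Example (not part of the diff): the two timestamp server-default styles that appear in
# this hunk, both now typed as Mapped[datetime]. func.current_timestamp() lets SQLAlchemy
# render the default for the target dialect, while sa.text("CURRENT_TIMESTAMP(0)") emits
# the literal PostgreSQL fragment (precision 0) the existing tables were created with.
# `TimestampedExample` is hypothetical.
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class TimestampedExample(Base):
    __tablename__ = "timestamped_examples"

    id: Mapped[int] = mapped_column(primary_key=True)
    # dialect-neutral: rendered as CURRENT_TIMESTAMP by the compiler
    created_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp()
    )
    # literal DDL fragment, PostgreSQL-specific because of the (0) precision
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
    )
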
sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), + sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), + sa.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) metadata_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 9f6d51b315..c4303f3cc5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -35,10 +35,10 @@ from .types import StringUUID class DifySetup(Base): __tablename__ = "dify_setups" - __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = mapped_column(db.String(255), nullable=False) - setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + version: Mapped[str] = mapped_column(String(255), nullable=False) + setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -69,33 +69,33 @@ class IconType(Enum): class App(Base): __tablename__ = "apps" - __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) - description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji - icon = db.Column(db.String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) + description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) + mode: Mapped[str] = mapped_column(String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon = db.Column(String(255)) + icon_background: Mapped[Optional[str]] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) - enable_site: Mapped[bool] = 
mapped_column(db.Boolean) - enable_api: Mapped[bool] = mapped_column(db.Boolean) - api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) - api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) - is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - tracing = mapped_column(db.Text, nullable=True) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) + enable_site: Mapped[bool] = mapped_column(sa.Boolean) + enable_api: Mapped[bool] = mapped_column(sa.Boolean) + api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) + api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) + is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + tracing = mapped_column(sa.Text, nullable=True) max_active_requests: Mapped[Optional[int]] created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def desc_or_prompt(self): @@ -302,36 +302,36 @@ class App(Base): class AppModelConfig(Base): __tablename__ = "app_model_configs" - __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) - configs = mapped_column(db.JSON, nullable=True) + provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) + configs = mapped_column(sa.JSON, nullable=True) created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - opening_statement = mapped_column(db.Text) - suggested_questions = mapped_column(db.Text) - suggested_questions_after_answer = mapped_column(db.Text) - speech_to_text = 
mapped_column(db.Text) - text_to_speech = mapped_column(db.Text) - more_like_this = mapped_column(db.Text) - model = mapped_column(db.Text) - user_input_form = mapped_column(db.Text) - dataset_query_variable = mapped_column(db.String(255)) - pre_prompt = mapped_column(db.Text) - agent_mode = mapped_column(db.Text) - sensitive_word_avoidance = mapped_column(db.Text) - retriever_resource = mapped_column(db.Text) - prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) - chat_prompt_config = mapped_column(db.Text) - completion_prompt_config = mapped_column(db.Text) - dataset_configs = mapped_column(db.Text) - external_data_tools = mapped_column(db.Text) - file_upload = mapped_column(db.Text) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + opening_statement = mapped_column(sa.Text) + suggested_questions = mapped_column(sa.Text) + suggested_questions_after_answer = mapped_column(sa.Text) + speech_to_text = mapped_column(sa.Text) + text_to_speech = mapped_column(sa.Text) + more_like_this = mapped_column(sa.Text) + model = mapped_column(sa.Text) + user_input_form = mapped_column(sa.Text) + dataset_query_variable = mapped_column(String(255)) + pre_prompt = mapped_column(sa.Text) + agent_mode = mapped_column(sa.Text) + sensitive_word_avoidance = mapped_column(sa.Text) + retriever_resource = mapped_column(sa.Text) + prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying")) + chat_prompt_config = mapped_column(sa.Text) + completion_prompt_config = mapped_column(sa.Text) + dataset_configs = mapped_column(sa.Text) + external_data_tools = mapped_column(sa.Text) + file_upload = mapped_column(sa.Text) @property def app(self): @@ -553,24 +553,24 @@ class AppModelConfig(Base): class RecommendedApp(Base): __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), - db.Index("recommended_app_app_id_idx", "app_id"), - db.Index("recommended_app_is_listed_idx", "is_listed", "language"), + sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + sa.Index("recommended_app_app_id_idx", "app_id"), + sa.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - description = mapped_column(db.JSON, nullable=False) - copyright = mapped_column(db.String(255), nullable=False) - privacy_policy = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.JSON, nullable=False) + copyright: Mapped[str] = mapped_column(String(255), nullable=False) + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = mapped_column(db.String(255), nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_listed = mapped_column(db.Boolean, nullable=False, default=True) - install_count = mapped_column(db.Integer, nullable=False, default=0) - language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -581,20 +581,20 @@ class RecommendedApp(Base): class InstalledApp(Base): __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="installed_app_pkey"), - db.Index("installed_app_tenant_id_idx", "tenant_id"), - db.Index("installed_app_app_id_idx", "app_id"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), + sa.PrimaryKeyConstraint("id", name="installed_app_pkey"), + sa.Index("installed_app_tenant_id_idx", "tenant_id"), + sa.Index("installed_app_app_id_idx", "app_id"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) app_owner_tenant_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used_at = mapped_column(db.DateTime, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + last_used_at = mapped_column(sa.DateTime, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -610,47 +610,47 @@ class InstalledApp(Base): class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="conversation_pkey"), - db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), + sa.PrimaryKeyConstraint("id", name="conversation_pkey"), + sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) app_model_config_id = mapped_column(StringUUID, nullable=True) - model_provider = mapped_column(db.String(255), nullable=True) - override_model_configs = mapped_column(db.Text) - model_id = mapped_column(db.String(255), nullable=True) - mode: Mapped[str] = mapped_column(db.String(255)) - name = mapped_column(db.String(255), nullable=False) - summary = mapped_column(db.Text) - _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - introduction = mapped_column(db.Text) - system_instruction = mapped_column(db.Text) - system_instruction_tokens = mapped_column(db.Integer, 
nullable=False, server_default=db.text("0")) - status = mapped_column(db.String(255), nullable=False) + model_provider = mapped_column(String(255), nullable=True) + override_model_configs = mapped_column(sa.Text) + model_id = mapped_column(String(255), nullable=True) + mode: Mapped[str] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255), nullable=False) + summary = mapped_column(sa.Text) + _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + introduction = mapped_column(sa.Text) + system_instruction = mapped_column(sa.Text) + system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + status: Mapped[str] = mapped_column(String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(db.String(255), nullable=True) + invoke_from = mapped_column(String(255), nullable=True) # ref: ConversationSource. - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) - read_at = mapped_column(db.DateTime) + read_at = mapped_column(sa.DateTime) read_account_id = mapped_column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def inputs(self): @@ -892,36 +892,36 @@ class Message(Base): Index("message_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - model_provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) - override_model_configs = mapped_column(db.Text) - conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) - _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - query: Mapped[str] = mapped_column(db.Text, nullable=False) - message = mapped_column(db.JSON, nullable=False) - message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column - answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - 
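
# Example (not part of the diff): two patterns visible in the Conversation hunk above. An
# attribute can be stored under a differently named column by passing the name to
# mapped_column() (here `_inputs` maps to the "inputs" column, wrapped by a property), and
# a column type can be inferred purely from its Mapped[...] annotation (dialogue_count).
# `ConversationExample` and the property bodies are hypothetical simplifications.
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ConversationExample(Base):
    __tablename__ = "conversation_examples"

    id: Mapped[int] = mapped_column(primary_key=True)
    # attribute `_inputs`, column "inputs"
    _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
    # no explicit column type: Integer is inferred from Mapped[int]
    dialogue_count: Mapped[int] = mapped_column(default=0)

    @property
    def inputs(self) -> dict:
        return dict(self._inputs or {})

    @inputs.setter
    def inputs(self, value: dict) -> None:
        self._inputs = value
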
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + model_provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) + override_model_configs = mapped_column(sa.Text) + conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + query: Mapped[str] = mapped_column(sa.Text, nullable=False) + message = mapped_column(sa.JSON, nullable=False) + message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column + answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) parent_message_id = mapped_column(StringUUID, nullable=True) - provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) - total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - error = mapped_column(db.Text) - message_metadata = mapped_column(db.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - from_source = mapped_column(db.String(255), nullable=False) + provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + total_price = mapped_column(sa.Numeric(10, 7)) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) + error = mapped_column(sa.Text) + message_metadata = mapped_column(sa.Text) + invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property @@ -1228,23 +1228,23 @@ class Message(Base): class MessageFeedback(Base): __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), - db.Index("message_feedback_app_idx", "app_id"), - db.Index("message_feedback_message_idx", "message_id", "from_source"), - 
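
# Example (not part of the diff): the Message hunk above keeps `answer` as a legacy
# `db.Column` with a TODO to convert it, and moves the conversations foreign key to
# sa.ForeignKey. A sketch of what the fully converted attributes could look like;
# `MessageExample`/`ConversationExample` are hypothetical stand-ins.
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ConversationExample(Base):
    __tablename__ = "conversation_examples"

    id: Mapped[str] = mapped_column(sa.String(36), primary_key=True)


class MessageExample(Base):
    __tablename__ = "message_examples"

    id: Mapped[str] = mapped_column(sa.String(36), primary_key=True)
    # sa.ForeignKey replaces db.ForeignKey; the column type is inferred from the target
    conversation_id: Mapped[str] = mapped_column(sa.ForeignKey("conversation_examples.id"), nullable=False)
    # the mapped_column form of the TODO-marked `db.Column(sa.Text, nullable=False)`
    answer: Mapped[str] = mapped_column(sa.Text, nullable=False)
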
db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), + sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + sa.Index("message_feedback_app_idx", "app_id"), + sa.Index("message_feedback_message_idx", "message_id", "from_source"), + sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) conversation_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) - rating = mapped_column(db.String(255), nullable=False) - content = mapped_column(db.Text) - from_source = mapped_column(db.String(255), nullable=False) + rating: Mapped[str] = mapped_column(String(255), nullable=False) + content = mapped_column(sa.Text) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1270,9 +1270,9 @@ class MessageFeedback(Base): class MessageFile(Base): __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_file_pkey"), - db.Index("message_file_message_idx", "message_id"), - db.Index("message_file_created_by_idx", "created_by"), + sa.PrimaryKeyConstraint("id", name="message_file_pkey"), + sa.Index("message_file_message_idx", "message_id"), + sa.Index("message_file_created_by_idx", "created_by"), ) def __init__( @@ -1296,37 +1296,37 @@ class MessageFile(Base): self.created_by_role = created_by_role.value self.created_by = created_by - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) - url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + type: Mapped[str] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) + url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class 
MessageAnnotation(Base): __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), - db.Index("message_annotation_app_idx", "app_id"), - db.Index("message_annotation_conversation_idx", "conversation_id"), - db.Index("message_annotation_message_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + sa.Index("message_annotation_app_idx", "app_id"), + sa.Index("message_annotation_conversation_idx", "conversation_id"), + sa.Index("message_annotation_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id")) + conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[Optional[str]] = mapped_column(StringUUID) - question = db.Column(db.Text, nullable=True) - content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + question = db.Column(sa.Text, nullable=True) + content = mapped_column(sa.Text, nullable=False) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1342,24 +1342,24 @@ class MessageAnnotation(Base): class AppAnnotationHitHistory(Base): __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), - db.Index("app_annotation_hit_histories_app_idx", "app_id"), - db.Index("app_annotation_hit_histories_account_idx", "account_id"), - db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), - db.Index("app_annotation_hit_histories_message_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + sa.Index("app_annotation_hit_histories_app_idx", "app_id"), + sa.Index("app_annotation_hit_histories_account_idx", "account_id"), + sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - source = mapped_column(db.Text, nullable=False) - question = mapped_column(db.Text, nullable=False) + source = mapped_column(sa.Text, nullable=False) + question = mapped_column(sa.Text, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - score = mapped_column(Float, nullable=False, server_default=db.text("0")) + created_at = 
mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + score = mapped_column(Float, nullable=False, server_default=sa.text("0")) message_id = mapped_column(StringUUID, nullable=False) - annotation_question = mapped_column(db.Text, nullable=False) - annotation_content = mapped_column(db.Text, nullable=False) + annotation_question = mapped_column(sa.Text, nullable=False) + annotation_content = mapped_column(sa.Text, nullable=False) @property def account(self): @@ -1380,18 +1380,18 @@ class AppAnnotationHitHistory(Base): class AppAnnotationSetting(Base): __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), - db.Index("app_annotation_settings_app_idx", "app_id"), + sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + sa.Index("app_annotation_settings_app_idx", "app_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0")) + score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0")) collection_binding_id = mapped_column(StringUUID, nullable=False) created_user_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = mapped_column(StringUUID, nullable=False) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def collection_binding_detail(self): @@ -1408,58 +1408,58 @@ class AppAnnotationSetting(Base): class OperationLog(Base): __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="operation_log_pkey"), - db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), + sa.PrimaryKeyConstraint("id", name="operation_log_pkey"), + sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - action = mapped_column(db.String(255), nullable=False) - content = mapped_column(db.JSON) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = mapped_column(db.String(255), nullable=False) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + action: Mapped[str] = mapped_column(String(255), nullable=False) + content = mapped_column(sa.JSON) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_ip: Mapped[str] = mapped_column(String(255), nullable=False) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(Base, UserMixin): __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="end_user_pkey"), - db.Index("end_user_session_id_idx", "session_id", "type"), - 
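
# Example (not part of the diff): how nullability reads after the conversion. With typed
# declarative mappings, an Optional[...] annotation alone makes the column nullable, while
# attributes left untyped keep an explicit nullable= flag, which is why both forms appear
# side by side in these hunks. `EndUserExample` is hypothetical.
from typing import Optional

import sqlalchemy as sa
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class EndUserExample(Base):
    __tablename__ = "end_user_examples"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Optional[...] alone implies NULLable
    external_user_id: Mapped[Optional[str]] = mapped_column(String(255))
    # spelling it out still works and matches the converted columns above
    name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    # untyped attribute: nullability must be stated explicitly
    session_note = mapped_column(sa.Text, nullable=True)
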
db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), + sa.PrimaryKeyConstraint("id", name="end_user_pkey"), + sa.Index("end_user_session_id_idx", "session_id", "type"), + sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(255), nullable=False) - external_user_id = mapped_column(db.String(255), nullable=True) - name = mapped_column(db.String(255)) - is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + external_user_id = mapped_column(String(255), nullable=True) + name = mapped_column(String(255)) + is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) session_id: Mapped[str] = mapped_column() - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMCPServer(Base): __tablename__ = "app_mcp_servers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), - db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), + sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), + sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) - server_code = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - parameters = mapped_column(db.Text, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) + server_code: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) + parameters = mapped_column(sa.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_server_code(n): @@ -1478,35 +1478,35 @@ class AppMCPServer(Base): class Site(Base): __tablename__ 
= "sites" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="site_pkey"), - db.Index("site_app_id_idx", "app_id"), - db.Index("site_code_idx", "code", "status"), + sa.PrimaryKeyConstraint("id", name="site_pkey"), + sa.Index("site_app_id_idx", "app_id"), + sa.Index("site_code_idx", "code", "status"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - title = mapped_column(db.String(255), nullable=False) - icon_type = mapped_column(db.String(255), nullable=True) - icon = mapped_column(db.String(255)) - icon_background = mapped_column(db.String(255)) - description = mapped_column(db.Text) - default_language = mapped_column(db.String(255), nullable=False) - chat_color_theme = mapped_column(db.String(255)) - chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = mapped_column(db.String(255)) - privacy_policy = mapped_column(db.String(255)) - show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + title: Mapped[str] = mapped_column(String(255), nullable=False) + icon_type = mapped_column(String(255), nullable=True) + icon = mapped_column(String(255)) + icon_background = mapped_column(String(255)) + description = mapped_column(sa.Text) + default_language: Mapped[str] = mapped_column(String(255), nullable=False) + chat_color_theme = mapped_column(String(255)) + chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + copyright = mapped_column(String(255)) + privacy_policy = mapped_column(String(255)) + show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - customize_domain = mapped_column(db.String(255)) - customize_token_strategy = mapped_column(db.String(255), nullable=False) - prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + customize_domain = mapped_column(String(255)) + customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = mapped_column(db.String(255)) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + code = mapped_column(String(255)) @property def custom_disclaimer(self): @@ -1535,19 +1535,19 @@ class Site(Base): class ApiToken(Base): __tablename__ = 
"api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_token_pkey"), - db.Index("api_token_app_id_type_idx", "app_id", "type"), - db.Index("api_token_token_idx", "token", "type"), - db.Index("api_token_tenant_idx", "tenant_id", "type"), + sa.PrimaryKeyConstraint("id", name="api_token_pkey"), + sa.Index("api_token_app_id_type_idx", "app_id", "type"), + sa.Index("api_token_token_idx", "token", "type"), + sa.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - token = mapped_column(db.String(255), nullable=False) - last_used_at = mapped_column(db.DateTime, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + type = mapped_column(String(16), nullable=False) + token: Mapped[str] = mapped_column(String(255), nullable=False) + last_used_at = mapped_column(sa.DateTime, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1561,27 +1561,27 @@ class ApiToken(Base): class UploadFile(Base): __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="upload_file_pkey"), - db.Index("upload_file_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="upload_file_pkey"), + sa.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) - key: Mapped[str] = mapped_column(db.String(255), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) - size: Mapped[int] = mapped_column(db.Integer, nullable=False) - extension: Mapped[str] = mapped_column(db.String(255), nullable=False) - mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + key: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + size: Mapped[int] = mapped_column(sa.Integer, nullable=False) + extension: Mapped[str] = mapped_column(String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(String(255), nullable=True) created_by_role: Mapped[str] = mapped_column( - db.String(255), nullable=False, server_default=db.text("'account'::character varying") + String(255), nullable=False, server_default=sa.text("'account'::character varying") ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - used_at: 
Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) + used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1623,71 +1623,71 @@ class UploadFile(Base): class ApiRequest(Base): __tablename__ = "api_requests" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_request_pkey"), - db.Index("api_request_token_idx", "tenant_id", "api_token_id"), + sa.PrimaryKeyConstraint("id", name="api_request_pkey"), + sa.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) api_token_id = mapped_column(StringUUID, nullable=False) - path = mapped_column(db.String(255), nullable=False) - request = mapped_column(db.Text, nullable=True) - response = mapped_column(db.Text, nullable=True) - ip = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + path: Mapped[str] = mapped_column(String(255), nullable=False) + request = mapped_column(sa.Text, nullable=True) + response = mapped_column(sa.Text, nullable=True) + ip: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(Base): __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_chain_pkey"), - db.Index("message_chain_message_id_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="message_chain_pkey"), + sa.Index("message_chain_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - input = mapped_column(db.Text, nullable=True) - output = mapped_column(db.Text, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + type: Mapped[str] = mapped_column(String(255), nullable=False) + input = mapped_column(sa.Text, nullable=True) + output = mapped_column(sa.Text, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) class MessageAgentThought(Base): __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), - db.Index("message_agent_thought_message_id_idx", "message_id"), - db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), + sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + sa.Index("message_agent_thought_message_id_idx", "message_id"), + sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) message_chain_id = 
mapped_column(StringUUID, nullable=True) - position = mapped_column(db.Integer, nullable=False) - thought = mapped_column(db.Text, nullable=True) - tool = mapped_column(db.Text, nullable=True) - tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_input = mapped_column(db.Text, nullable=True) - observation = mapped_column(db.Text, nullable=True) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + thought = mapped_column(sa.Text, nullable=True) + tool = mapped_column(sa.Text, nullable=True) + tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_input = mapped_column(sa.Text, nullable=True) + observation = mapped_column(sa.Text, nullable=True) # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design - tool_process_data = mapped_column(db.Text, nullable=True) - message = mapped_column(db.Text, nullable=True) - message_token = mapped_column(db.Integer, nullable=True) - message_unit_price = mapped_column(db.Numeric, nullable=True) - message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - message_files = mapped_column(db.Text, nullable=True) - answer = db.Column(db.Text, nullable=True) - answer_token = mapped_column(db.Integer, nullable=True) - answer_unit_price = mapped_column(db.Numeric, nullable=True) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = mapped_column(db.Integer, nullable=True) - total_price = mapped_column(db.Numeric, nullable=True) - currency = mapped_column(db.String, nullable=True) - latency = mapped_column(db.Float, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + tool_process_data = mapped_column(sa.Text, nullable=True) + message = mapped_column(sa.Text, nullable=True) + message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + message_unit_price = mapped_column(sa.Numeric, nullable=True) + message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + message_files = mapped_column(sa.Text, nullable=True) + answer = db.Column(sa.Text, nullable=True) + answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + answer_unit_price = mapped_column(sa.Numeric, nullable=True) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + total_price = mapped_column(sa.Numeric, nullable=True) + currency = mapped_column(String, nullable=True) + latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def files(self) -> list: @@ -1769,80 +1769,80 @@ class MessageAgentThought(Base): class DatasetRetrieverResource(Base): __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), - 
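
# Example (not part of the diff): the numeric columns in MessageAgentThought above follow
# the same pattern as the rest of the conversion: sa.Numeric keeps its precision/scale and
# literal server default, and optional metrics gain Mapped[Optional[...]] annotations.
# `ThoughtExample` is a hypothetical stand-in.
from decimal import Decimal
from typing import Optional

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ThoughtExample(Base):
    __tablename__ = "thought_examples"

    id: Mapped[int] = mapped_column(primary_key=True)
    # NUMERIC(10, 7) with a literal default, as in the price-unit columns
    message_price_unit: Mapped[Decimal] = mapped_column(
        sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
    )
    # nullable metrics: Optional[...] annotation plus the original nullable=True flag
    tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
    latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
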
db.Index("dataset_retriever_resource_message_id_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + sa.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - dataset_name = mapped_column(db.Text, nullable=False) + dataset_name = mapped_column(sa.Text, nullable=False) document_id = mapped_column(StringUUID, nullable=True) - document_name = mapped_column(db.Text, nullable=False) - data_source_type = mapped_column(db.Text, nullable=True) + document_name = mapped_column(sa.Text, nullable=False) + data_source_type = mapped_column(sa.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score = mapped_column(db.Float, nullable=True) - content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - segment_position = mapped_column(db.Integer, nullable=True) - index_node_hash = mapped_column(db.Text, nullable=True) - retriever_from = mapped_column(db.Text, nullable=False) + score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + content = mapped_column(sa.Text, nullable=False) + hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + index_node_hash = mapped_column(sa.Text, nullable=True) + retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) class Tag(Base): __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tag_pkey"), - db.Index("tag_type_idx", "type"), - db.Index("tag_name_idx", "name"), + sa.PrimaryKeyConstraint("id", name="tag_pkey"), + sa.Index("tag_type_idx", "type"), + sa.Index("tag_name_idx", "name"), ) TAG_TYPE_LIST = ["knowledge", "app"] - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - name = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(Base): __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), - db.Index("tag_bind_target_id_idx", "target_id"), - db.Index("tag_bind_tag_id_idx", "tag_id"), + sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + 
sa.Index("tag_bind_target_id_idx", "target_id"), + sa.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) tag_id = mapped_column(StringUUID, nullable=True) target_id = mapped_column(StringUUID, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(Base): __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), - db.Index("trace_app_config_app_id_idx", "app_id"), + sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + sa.Index("trace_app_config_app_id_idx", "app_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - tracing_provider = mapped_column(db.String(255), nullable=True) - tracing_config = mapped_column(db.JSON, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + tracing_provider = mapped_column(String(255), nullable=True) + tracing_config = mapped_column(sa.JSON, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/provider.py b/api/models/provider.py index 1e25f0c90f..4ea2c59fdb 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -2,11 +2,11 @@ from datetime import datetime from enum import Enum from typing import Optional -from sqlalchemy import func, text +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base -from .engine import db from .types import StringUUID @@ -47,31 +47,31 @@ class Provider(Base): __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_pkey"), - db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), - db.UniqueConstraint( + sa.PrimaryKeyConstraint("id", name="provider_pkey"), + sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + sa.UniqueConstraint( "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" ), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_type: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'custom'::character varying") + String(40), 
nullable=False, server_default=text("'custom'::character varying") ) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( - db.String(40), nullable=True, server_default=text("''::character varying") + String(40), nullable=True, server_default=text("''::character varying") ) - quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) - quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) + quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -104,80 +104,80 @@ class ProviderModel(Base): __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_model_pkey"), - db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), - db.UniqueConstraint( + sa.PrimaryKeyConstraint("id", name="provider_model_pkey"), + sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + sa.UniqueConstraint( "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" ), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, 
server_default=func.current_timestamp()) class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), - db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), - db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), + sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_order_pkey"), - db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), + sa.PrimaryKeyConstraint("id", name="provider_order_pkey"), + sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: 
Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(db.String(40)) - total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) + quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer) payment_status: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -187,19 +187,19 @@ class ProviderModelSetting(Base): __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), - db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = 
mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -209,17 +209,17 @@ class LoadBalancingModelConfig(Base): __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), - db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 100e0d96ef..8456d65a87 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,49 +1,51 @@ import json +from datetime import datetime +from typing import Optional -from sqlalchemy import func +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Mapped, mapped_column from models.base import Base -from .engine import db from .types import StringUUID class DataSourceOauthBinding(Base): __tablename__ = "data_source_oauth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="source_binding_pkey"), - db.Index("source_binding_tenant_id_idx", "tenant_id"), - db.Index("source_info_idx", "source_info", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="source_binding_pkey"), + 
sa.Index("source_binding_tenant_id_idx", "tenant_id"), + sa.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - access_token = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + access_token: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) source_info = mapped_column(JSONB, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) class DataSourceApiKeyAuthBinding(Base): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), - db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), - db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), + sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - category = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) - credentials = mapped_column(db.Text, nullable=True) # JSON - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + category: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) + credentials = mapped_column(sa.Text, nullable=True) # JSON + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 3e5ebd2099..ab700c553c 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,7 +1,9 @@ from datetime import datetime from typing import Optional +import sqlalchemy as sa from celery import states # type: ignore +from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now @@ -15,23 +17,23 @@ class CeleryTask(Base): 
__tablename__ = "celery_taskmeta" - id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = mapped_column(db.String(155), unique=True) - status = mapped_column(db.String(50), default=states.PENDING) + id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) + task_id = mapped_column(String(155), unique=True) + status = mapped_column(String(50), default=states.PENDING) result = mapped_column(db.PickleType, nullable=True) date_done = mapped_column( - db.DateTime, + DateTime, default=lambda: naive_utc_now(), onupdate=lambda: naive_utc_now(), nullable=True, ) - traceback = mapped_column(db.Text, nullable=True) - name = mapped_column(db.String(155), nullable=True) - args = mapped_column(db.LargeBinary, nullable=True) - kwargs = mapped_column(db.LargeBinary, nullable=True) - worker = mapped_column(db.String(155), nullable=True) - retries = mapped_column(db.Integer, nullable=True) - queue = mapped_column(db.String(155), nullable=True) + traceback = mapped_column(sa.Text, nullable=True) + name = mapped_column(String(155), nullable=True) + args = mapped_column(sa.LargeBinary, nullable=True) + kwargs = mapped_column(sa.LargeBinary, nullable=True) + worker = mapped_column(String(155), nullable=True) + retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + queue = mapped_column(String(155), nullable=True) class CeleryTaskSet(Base): @@ -40,8 +42,8 @@ class CeleryTaskSet(Base): __tablename__ = "celery_tasksetmeta" id: Mapped[int] = mapped_column( - db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True + sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) - taskset_id = mapped_column(db.String(155), unique=True) + taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 68f4211e59..408c1371c2 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, func +from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers @@ -25,33 +25,33 @@ from .types import StringUUID class ToolOAuthSystemClient(Base): __tablename__ = "tool_oauth_system_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), - db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + plugin_id = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # oauth 
params of the tool provider - encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) # tenant level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthTenantClient(Base): __tablename__ = "tool_oauth_tenant_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), - db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), + sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) # oauth params of the tool provider - encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) @property def oauth_params(self) -> dict: @@ -65,35 +65,35 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), + sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), + sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), ) # id of the tool provider - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column( - db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying") ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # name of the tool provider - provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + provider: Mapped[str] = mapped_column(String(256), nullable=False) # credential of the tool provider - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) - is_default: Mapped[bool] = mapped_column(db.Boolean, 
nullable=False, server_default=db.text("false")) + is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + String(32), nullable=False, server_default=sa.text("'api-key'::character varying") ) - expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) + expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) @property def credentials(self) -> dict: @@ -107,35 +107,35 @@ class ApiToolProvider(Base): __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), + sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the api provider - name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying")) # icon - icon = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema - schema = mapped_column(db.Text, nullable=False) - schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) + schema = mapped_column(sa.Text, nullable=False) + schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) # who created this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # description of the provider - description = mapped_column(db.Text, nullable=False) + description = mapped_column(sa.Text, nullable=False) # json format tools - tools_str = mapped_column(db.Text, nullable=False) + tools_str = mapped_column(sa.Text, nullable=False) # json format credentials - credentials_str = mapped_column(db.Text, nullable=False) + credentials_str = mapped_column(sa.Text, nullable=False) # privacy policy - privacy_policy = mapped_column(db.String(255), nullable=True) + privacy_policy = mapped_column(String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -167,17 +167,17 @@ class ToolLabelBinding(Base): __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), - db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), + sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + sa.UniqueConstraint("tool_id", "label_name", 
name="unique_tool_label_bind"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # tool id - tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) + tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + tool_type: Mapped[str] = mapped_column(String(40), nullable=False) # label name - label_name: Mapped[str] = mapped_column(db.String(40), nullable=False) + label_name: Mapped[str] = mapped_column(String(40), nullable=False) class WorkflowToolProvider(Base): @@ -187,38 +187,38 @@ class WorkflowToolProvider(Base): __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), + sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the workflow provider - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider - label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # icon - icon: Mapped[str] = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # app id of the workflow provider app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # version of the workflow provider - version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # description of the provider - description: Mapped[str] = mapped_column(db.Text, nullable=False) + description: Mapped[str] = mapped_column(sa.Text, nullable=False) # parameter configuration - parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") + parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) @property @@ -245,38 +245,38 @@ class MCPToolProvider(Base): 
__tablename__ = "tool_mcp_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), - db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), - db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), - db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), + sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), + sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), + sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), + sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the mcp provider - name: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_identifier: Mapped[str] = mapped_column(String(64), nullable=False) # encrypted url of the mcp provider - server_url: Mapped[str] = mapped_column(db.Text, nullable=False) + server_url: Mapped[str] = mapped_column(sa.Text, nullable=False) # hash of server_url for uniqueness check - server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False) # icon of the mcp provider - icon: Mapped[str] = mapped_column(db.String(255), nullable=True) + icon: Mapped[str] = mapped_column(String(255), nullable=True) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # encrypted credentials - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True) # authed - authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False) + authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) # tools - tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") + tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) def load_user(self) -> Account | None: @@ -347,35 +347,35 @@ class ToolModelInvoke(Base): """ __tablename__ = "tool_model_invokes" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # who invoke this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = mapped_column(db.String(255), nullable=False) + 
provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type = mapped_column(db.String(40), nullable=False) + tool_type = mapped_column(String(40), nullable=False) # tool name - tool_name = mapped_column(db.String(128), nullable=False) + tool_name = mapped_column(String(128), nullable=False) # invoke parameters - model_parameters = mapped_column(db.Text, nullable=False) + model_parameters = mapped_column(sa.Text, nullable=False) # prompt messages - prompt_messages = mapped_column(db.Text, nullable=False) + prompt_messages = mapped_column(sa.Text, nullable=False) # invoke response - model_response = mapped_column(db.Text, nullable=False) + model_response = mapped_column(sa.Text, nullable=False) - prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) - total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + total_price = mapped_column(sa.Numeric(10, 7)) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -386,13 +386,13 @@ class ToolConversationVariables(Base): __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), + sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index("user_id_idx", "user_id"), - db.Index("conversation_id_idx", "conversation_id"), + sa.Index("user_id_idx", "user_id"), + sa.Index("conversation_id_idx", "conversation_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # conversation user id user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -400,10 +400,10 @@ class ToolConversationVariables(Base): # conversation id conversation_id = mapped_column(StringUUID, nullable=False) # variables pool - variables_str = mapped_column(db.Text, nullable=False) + variables_str = mapped_column(sa.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, 
nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -417,11 +417,11 @@ class ToolFile(Base): __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_file_pkey"), - db.Index("tool_file_conversation_id_idx", "conversation_id"), + sa.PrimaryKeyConstraint("id", name="tool_file_pkey"), + sa.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id @@ -429,11 +429,11 @@ class ToolFile(Base): # conversation id conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # file key - file_key: Mapped[str] = mapped_column(db.String(255), nullable=False) + file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type - mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) + mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + original_url: Mapped[str] = mapped_column(String(2048), nullable=True) # name name: Mapped[str] = mapped_column(default="") # size @@ -448,30 +448,30 @@ class DeprecatedPublishedAppTool(Base): __tablename__ = "tool_published_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), - db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), + sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # id of the app app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False) user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who published this tool - description = mapped_column(db.Text, nullable=False) + description = mapped_column(sa.Text, nullable=False) # llm_description of the tool, for LLM - llm_description = mapped_column(db.Text, nullable=False) + llm_description = mapped_column(sa.Text, nullable=False) # query description, query will be seem as a parameter of the tool, # to describe this parameter to llm, we need this field - query_description = mapped_column(db.Text, nullable=False) + query_description = mapped_column(sa.Text, nullable=False) # query name, the name of the query parameter - query_name = mapped_column(db.String(40), nullable=False) + query_name = mapped_column(String(40), nullable=False) # name of the tool provider - tool_name = mapped_column(db.String(40), nullable=False) + tool_name = mapped_column(String(40), nullable=False) # author - author = mapped_column(db.String(40), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(String(40), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(sa.DateTime, nullable=False, 
server_default=sa.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index ce00f4010f..74f99e187b 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,7 @@ -from sqlalchemy import func +from datetime import datetime + +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -11,18 +14,18 @@ from .types import StringUUID class SavedMessage(Base): __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="saved_message_pkey"), - db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), + sa.PrimaryKeyConstraint("id", name="saved_message_pkey"), + sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=sa.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -32,15 +35,15 @@ class SavedMessage(Base): class PinnedConversation(Base): __tablename__ = "pinned_conversations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), - db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), + sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=sa.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index d89db6c7da..453a650f84 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -6,8 +6,9 @@ from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 +import sqlalchemy as sa from flask_login import current_user -from sqlalchemy import orm +from sqlalchemy import DateTime, orm from core.file.constants import maybe_file_object from core.file.models import File @@ -24,8 +25,7 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode 
-import sqlalchemy as sa -from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func +from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, declared_attr, mapped_column from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE @@ -117,33 +117,33 @@ class Workflow(Base): __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_pkey"), - db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), + sa.PrimaryKeyConstraint("id", name="workflow_pkey"), + sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - version: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(default="", server_default="") marked_comment: Mapped[str] = mapped_column(default="", server_default="") graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=naive_utc_now(), server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( - "environment_variables", db.Text, nullable=False, server_default="{}" + "environment_variables", sa.Text, nullable=False, server_default="{}" ) _conversation_variables: Mapped[str] = mapped_column( - "conversation_variables", db.Text, nullable=False, server_default="{}" + "conversation_variables", sa.Text, nullable=False, server_default="{}" ) VERSION_DRAFT = "draft" @@ -491,31 +491,31 @@ class WorkflowRun(Base): __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), - db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(db.String(255)) - triggered_from: Mapped[str] = mapped_column(db.String(255)) - version: Mapped[str] = mapped_column(db.String(255)) - graph: Mapped[Optional[str]] = mapped_column(db.Text) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, 
partial-succeeded + type: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) + version: Mapped[str] = mapped_column(String(255)) + graph: Mapped[Optional[str]] = mapped_column(sa.Text) + inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) + error: Mapped[Optional[str]] = mapped_column(sa.Text) + elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) - total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user + total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) @property def created_by_account(self): @@ -704,29 +704,29 @@ class WorkflowNodeExecutionModel(Base): ), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) - index: Mapped[int] = mapped_column(db.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_id: Mapped[str] = mapped_column(db.String(255)) - node_type: Mapped[str] = mapped_column(db.String(255)) - title: Mapped[str] = mapped_column(db.String(255)) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) - process_data: Mapped[Optional[str]] = mapped_column(db.Text) - outputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) - error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) - execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(db.String(255)) + index: Mapped[int] = mapped_column(sa.Integer) + predecessor_node_id: 
Mapped[Optional[str]] = mapped_column(String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_id: Mapped[str] = mapped_column(String(255)) + node_type: Mapped[str] = mapped_column(String(255)) + title: Mapped[str] = mapped_column(String(255)) + inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + process_data: Mapped[Optional[str]] = mapped_column(sa.Text) + outputs: Mapped[Optional[str]] = mapped_column(sa.Text) + status: Mapped[str] = mapped_column(String(255)) + error: Mapped[Optional[str]] = mapped_column(sa.Text) + elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) @property def created_by_account(self): @@ -834,19 +834,19 @@ class WorkflowAppLog(Base): __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), - db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), + sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -864,6 +864,19 @@ class WorkflowAppLog(Base): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "workflow_id": self.workflow_id, + "workflow_run_id": self.workflow_run_id, + "created_from": self.created_from, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + } + class ConversationVariable(Base): __tablename__ = "workflow_conversation_variables" @@ -871,12 +884,12 @@ class ConversationVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) - data: Mapped[str] = mapped_column(db.Text, nullable=False) + data: Mapped[str] = mapped_column(sa.Text, 
nullable=False)
     created_at: Mapped[datetime] = mapped_column(
-        db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
+        DateTime, nullable=False, server_default=func.current_timestamp(), index=True
     )
     updated_at: Mapped[datetime] = mapped_column(
-        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
     )
     def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
@@ -933,17 +946,17 @@ class WorkflowDraftVariable(Base):
     __allow_unmapped__ = True
     # id is the unique identifier of a draft variable.
-    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
     created_at: Mapped[datetime] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=False,
         default=_naive_utc_datetime,
         server_default=func.current_timestamp(),
     )
     updated_at: Mapped[datetime] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=False,
         default=_naive_utc_datetime,
         server_default=func.current_timestamp(),
@@ -958,7 +971,7 @@ class WorkflowDraftVariable(Base):
     #
     # If it's not edited after creation, its value is `None`.
     last_edited_at: Mapped[datetime | None] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=True,
         default=None,
     )
diff --git a/api/pyproject.toml b/api/pyproject.toml
index be42b509ed..a86ec7ee6b 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -49,6 +49,8 @@ dependencies = [
     "opentelemetry-instrumentation==0.48b0",
     "opentelemetry-instrumentation-celery==0.48b0",
     "opentelemetry-instrumentation-flask==0.48b0",
+    "opentelemetry-instrumentation-redis==0.48b0",
+    "opentelemetry-instrumentation-requests==0.48b0",
     "opentelemetry-instrumentation-sqlalchemy==0.48b0",
     "opentelemetry-propagator-b3==1.27.0",
     # opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0),
@@ -114,6 +116,7 @@ dev = [
     "pytest-cov~=4.1.0",
     "pytest-env~=1.1.3",
     "pytest-mock~=3.14.0",
+    "testcontainers~=4.10.0",
     "types-aiofiles~=24.1.0",
     "types-beautifulsoup4~=4.12.0",
     "types-cachetools~=5.5.0",
@@ -191,6 +194,7 @@ vdb = [
     "alibabacloud_tea_openapi~=0.3.9",
     "chromadb==0.5.20",
     "clickhouse-connect~=0.7.16",
+    "clickzetta-connector-python>=0.8.102",
     "couchbase~=4.3.0",
     "elasticsearch==8.14.0",
     "opensearch-py==2.4.0",
@@ -210,3 +214,4 @@ vdb = [
     "xinference-client~=1.2.2",
     "mo-vector~=0.1.13",
 ]
+
diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py
index 2298acf6eb..2b74fb2dd0 100644
--- a/api/schedule/clean_embedding_cache_task.py
+++ b/api/schedule/clean_embedding_cache_task.py
@@ -3,7 +3,7 @@ import time
 import click
 from sqlalchemy import text
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError
 import app
 from configs import dify_config
@@ -27,8 +27,8 @@ def clean_embedding_cache_task():
                 .all()
             )
             embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
         if embedding_ids:
             for embedding_id in embedding_ids:
                 db.session.execute(
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 4c35745959..a896c818a5 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -3,7 +3,7 @@ import logging
 import time
 import click
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError
 import app
 from configs import dify_config
@@ -42,8 +42,8 @@ def clean_messages():
                 .all()
             )
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
         if not messages:
             break
         for message in messages:
diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py
index 7887835bc5..940da5309e 100644
--- a/api/schedule/clean_unused_datasets_task.py
+++ b/api/schedule/clean_unused_datasets_task.py
@@ -3,7 +3,7 @@ import time
 import click
 from sqlalchemy import func, select
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError
 import app
 from configs import dify_config
@@ -65,8 +65,8 @@ def clean_unused_datasets_task():
            datasets = db.paginate(stmt, page=1, per_page=50)
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
        if datasets.items is None or len(datasets.items) == 0:
            break
        for dataset in datasets:
@@ -146,8 +146,8 @@ def clean_unused_datasets_task():
            )
            datasets = db.paginate(stmt, page=1, per_page=50)
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
        if datasets.items is None or len(datasets.items) == 0:
            break
        for dataset in datasets:
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
index a05e1358ed..f0d3bed057 100644
--- a/api/schedule/queue_monitor_task.py
+++ b/api/schedule/queue_monitor_task.py
@@ -1,8 +1,8 @@
 import logging
 from datetime import datetime
-from urllib.parse import urlparse
 import click
+from kombu.utils.url import parse_url  # type: ignore
 from redis import Redis
 import app
@@ -10,16 +10,13 @@ from configs import dify_config
 from extensions.ext_database import db
 from libs.email_i18n import EmailType, get_email_i18n_service
-# Create a dedicated Redis connection (using the same configuration as Celery)
-celery_broker_url = dify_config.CELERY_BROKER_URL
-
-parsed = urlparse(celery_broker_url)
-host = parsed.hostname or "localhost"
-port = parsed.port or 6379
-password = parsed.password or None
-redis_db = parsed.path.strip("/") or "1"  # type: ignore
-
-celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
+redis_config = parse_url(dify_config.CELERY_BROKER_URL)
+celery_redis = Redis(
+    host=redis_config.get("hostname") or "localhost",
+    port=redis_config.get("port") or 6379,
+    password=redis_config.get("password") or None,
+    db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
+)
 @app.celery.task(queue="monitor")
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index cfa917daf6..b7a047914e 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -266,6 +266,54 @@ class AppAnnotationService:
             annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
         )
+    @classmethod
+    def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
+        # get app info
+        app = (
+            db.session.query(App)
+            .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
+
+        if not app:
+            raise NotFound("App not found")
+
+        # Fetch annotations and their settings in a single query
+        annotations_to_delete = (
+            db.session.query(MessageAnnotation, AppAnnotationSetting)
+            .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
+            .filter(MessageAnnotation.id.in_(annotation_ids))
+            .all()
+        )
+
+        if not annotations_to_delete:
+            return {"deleted_count": 0}
+
+        # Step 1: Extract IDs for bulk operations
+        annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
+
+        # Step 2: Bulk delete hit histories in a single query
+        db.session.query(AppAnnotationHitHistory).filter(
+            AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
+        ).delete(synchronize_session=False)
+
+        # Step 3: Trigger async tasks for search index deletion
+        for annotation, annotation_setting in annotations_to_delete:
+            if annotation_setting:
+                delete_annotation_index_task.delay(
+                    annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
+                )
+
+        # Step 4: Bulk delete annotations in a single query
+        deleted_count = (
+            db.session.query(MessageAnnotation)
+            .filter(MessageAnnotation.id.in_(annotation_ids_to_delete))
+            .delete(synchronize_session=False)
+        )
+
+        db.session.commit()
+        return {"deleted_count": deleted_count}
+
     @classmethod
     def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
         # get app info
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index fe0efd061d..2aa9f6cabd 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -12,6 +12,7 @@ import yaml  # type: ignore
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from packaging import version
+from packaging.version import parse as parse_version
 from pydantic import BaseModel, Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -269,7 +270,7 @@ class AppDslService:
             check_dependencies_pending_data = None
             if dependencies:
                 check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
-            elif imported_version <= "0.1.5":
+            elif parse_version(imported_version) <= parse_version("0.1.5"):
                 if "workflow" in data:
                     graph = data.get("workflow", {}).get("graph", {})
                     dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 6f7e705b52..6792324ec8 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -1,5 +1,6 @@
+import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Union
+from typing import Any, Optional, Union
 from openai._exceptions import RateLimitError
@@ -15,6 +16,7 @@ from libs.helper import RateLimiter
 from models.model import Account, App, AppMode, EndUser
 from models.workflow import Workflow
 from services.billing_service import BillingService
+from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
 from services.errors.llm import InvokeRateLimitError
 from services.workflow_service import WorkflowService
@@ -86,7 +88,8 @@ class AppGenerateService:
                 request_id=request_id,
             )
         elif app_model.mode == AppMode.ADVANCED_CHAT.value:
-            workflow = cls._get_workflow(app_model, invoke_from)
+            workflow_id = args.get("workflow_id")
+            workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
             return rate_limit.generate(
                 AdvancedChatAppGenerator.convert_to_event_stream(
                     AdvancedChatAppGenerator().generate(
@@ -101,7 +104,8 @@ class AppGenerateService:
                 request_id=request_id,
             )
         elif app_model.mode == AppMode.WORKFLOW.value:
-            workflow = cls._get_workflow(app_model, invoke_from)
+            workflow_id = args.get("workflow_id")
+            workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
             return rate_limit.generate(
                 WorkflowAppGenerator.convert_to_event_stream(
                     WorkflowAppGenerator().generate(
@@ -210,14 +214,27 @@ class AppGenerateService:
             )
@classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow: """ Get workflow :param app_model: app model :param invoke_from: invoke from + :param workflow_id: optional workflow id to specify a specific version :return: """ workflow_service = WorkflowService() + + # If workflow_id is specified, get the specific workflow version + if workflow_id: + try: + workflow_uuid = uuid.UUID(workflow_id) + except ValueError: + raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ") + workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + if not workflow: + raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}") + return workflow + if invoke_from == InvokeFrom.DEBUGGER: # fetch draft workflow by app_model workflow = workflow_service.get_draft_workflow(app_model=app_model) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 5a12aa2e54..476fce0057 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -159,9 +159,9 @@ class BillingService: ): limiter_key = f"{account_id}:{tenant_id}" if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key): - from controllers.console.error import CompilanceRateLimitError + from controllers.console.error import ComplianceRateLimitError - raise CompilanceRateLimitError() + raise ComplianceRateLimitError() json = { "doc_name": doc_name, diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index d057a14afb..b28afcaa41 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -13,7 +13,19 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant -from models.model import App, Conversation, Message +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from models.workflow import WorkflowAppLog from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService @@ -21,6 +33,85 @@ logger = logging.getLogger(__name__) class ClearFreePlanTenantExpiredLogs: + @classmethod + def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None: + """ + Clean up message-related tables to avoid data redundancy. + This method cleans up tables that have foreign key relationships with Message. 
+ + Args: + session: Database session, the same with the one in process_tenant method + tenant_id: Tenant ID for logging purposes + batch_message_ids: List of message IDs to clean up + """ + if not batch_message_ids: + return + + # Clean up each related table + related_tables = [ + (MessageFeedback, "message_feedbacks"), + (MessageFile, "message_files"), + (MessageAnnotation, "message_annotations"), + (MessageChain, "message_chains"), + (MessageAgentThought, "message_agent_thoughts"), + (AppAnnotationHitHistory, "app_annotation_hit_histories"), + (SavedMessage, "saved_messages"), + ] + + for model, table_name in related_tables: + # Query records related to expired messages + records = ( + session.query(model) + .filter( + model.message_id.in_(batch_message_ids), # type: ignore + ) + .all() + ) + + if len(records) == 0: + continue + + # Save records before deletion + record_ids = [record.id for record in records] + try: + record_data = [] + for record in records: + try: + if hasattr(record, "to_dict"): + record_data.append(record.to_dict()) + else: + # if record doesn't have to_dict method, we need to transform it to dict manually + record_dict = {} + for column in record.__table__.columns: + record_dict[column.name] = getattr(record, column.name) + record_data.append(record_dict) + except Exception: + logger.exception("Failed to transform %s record: %s", table_name, record.id) + continue + + if record_data: + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(record_data), + ).encode("utf-8"), + ) + except Exception: + logger.exception("Failed to save %s records", table_name) + + session.query(model).filter( + model.id.in_(record_ids), # type: ignore + ).delete(synchronize_session=False) + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(record_ids)} " + f"{table_name} records for tenant {tenant_id}" + ) + ) + @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): @@ -58,6 +149,7 @@ class ClearFreePlanTenantExpiredLogs: Message.id.in_(message_ids), ).delete(synchronize_session=False) + cls._clear_message_related_tables(session, tenant_id, message_ids) session.commit() click.echo( @@ -199,6 +291,48 @@ class ClearFreePlanTenantExpiredLogs: if len(workflow_runs) < batch: break + while True: + with Session(db.engine).no_autoflush as session: + workflow_app_logs = ( + session.query(WorkflowAppLog) + .filter( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), + ) + .limit(batch) + .all() + ) + + if len(workflow_app_logs) == 0: + break + + # save workflow app logs + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs], + ), + ).encode("utf-8"), + ) + + workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] + + # delete workflow app logs + session.query(WorkflowAppLog).filter( + WorkflowAppLog.id.in_(workflow_app_log_ids), + ).delete(synchronize_session=False) + session.commit() + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}" + f" workflow app logs for tenant {tenant_id}" + ) + ) + @classmethod def 
process(cls, days: int, batch: int, tenant_ids: list[str]): """ diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 206c832a20..692a3639cd 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,12 +1,15 @@ from collections.abc import Callable, Sequence -from typing import Optional, Union +from typing import Any, Optional, Union from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator +from core.variables.types import SegmentType +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db +from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ConversationVariable @@ -15,6 +18,7 @@ from models.model import App, Conversation, EndUser, Message from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, LastConversationNotExistsError, ) from services.errors.message import MessageNotExistsError @@ -220,3 +224,82 @@ class ConversationService: ] return InfiniteScrollPagination(variables, limit, has_more) + + @classmethod + def update_conversation_variable( + cls, + app_model: App, + conversation_id: str, + variable_id: str, + user: Optional[Union[Account, EndUser]], + new_value: Any, + ) -> dict: + """ + Update a conversation variable's value. + + Args: + app_model: The app model + conversation_id: The conversation ID + variable_id: The variable ID to update + user: The user (Account or EndUser) + new_value: The new value for the variable + + Returns: + Dictionary containing the updated variable information + + Raises: + ConversationNotExistsError: If the conversation doesn't exist + ConversationVariableNotExistsError: If the variable doesn't exist + ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type + """ + # Verify conversation exists and user has access + conversation = cls.get_conversation(app_model, conversation_id, user) + + # Get the existing conversation variable + stmt = ( + select(ConversationVariable) + .where(ConversationVariable.app_id == app_model.id) + .where(ConversationVariable.conversation_id == conversation.id) + .where(ConversationVariable.id == variable_id) + ) + + with Session(db.engine) as session: + existing_variable = session.scalar(stmt) + if not existing_variable: + raise ConversationVariableNotExistsError() + + # Convert existing variable to Variable object + current_variable = existing_variable.to_variable() + + # Validate that the new value type matches the expected variable type + expected_type = SegmentType(current_variable.value_type) + if not expected_type.is_valid(new_value): + inferred_type = SegmentType.infer_segment_type(new_value) + raise ConversationVariableTypeMismatchError( + f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, " + f"but got {inferred_type.value if inferred_type else 'unknown'} type" + ) + + # Create updated variable with new value only, preserving everything else + updated_variable_dict = { + "id": current_variable.id, + "name": current_variable.name, + "description": current_variable.description, + "value_type": current_variable.value_type, + 
"value": new_value, + "selector": current_variable.selector, + } + + updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) + + # Use the conversation variable updater to persist the changes + updater = conversation_variable_updater_factory() + updater.update(conversation_id, updated_variable) + updater.flush() + + # Return the updated variable data + return { + "created_at": existing_variable.created_at, + "updated_at": naive_utc_now(), # Update timestamp + **updated_variable.model_dump(), + } diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 1280399990..8934608da1 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -266,7 +266,7 @@ class DatasetService: "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError(f"The dataset in unavailable, due to: {ex.description}") + raise ValueError(f"The dataset is unavailable, due to: {ex.description}") @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): @@ -370,7 +370,7 @@ class DatasetService: raise ValueError("External knowledge api id is required.") # Update metadata fields dataset.updated_by = user.id if user else None - dataset.updated_at = datetime.datetime.utcnow() + dataset.updated_at = naive_utc_now() db.session.add(dataset) # Update external knowledge binding @@ -2040,6 +2040,7 @@ class SegmentService: db.session.add(segment_document) # update document word count + assert document.word_count is not None document.word_count += segment_document.word_count db.session.add(document) db.session.commit() @@ -2124,6 +2125,7 @@ class SegmentService: else: keywords_list.append(None) # update document word count + assert document.word_count is not None document.word_count += increment_word_count db.session.add(document) try: @@ -2185,6 +2187,7 @@ class SegmentService: db.session.commit() # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task @@ -2260,6 +2263,7 @@ class SegmentService: word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) db.session.add(segment) @@ -2323,6 +2327,7 @@ class SegmentService: delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count + assert document.word_count is not None document.word_count -= segment.word_count db.session.add(document) db.session.commit() @@ -2367,7 +2372,7 @@ class SegmentService: ) if not segments: return - real_deal_segmment_ids = [] + real_deal_segment_ids = [] for segment in segments: indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) @@ -2377,10 +2382,10 @@ class SegmentService: segment.disabled_at = None segment.disabled_by = None db.session.add(segment) - real_deal_segmment_ids.append(segment.id) + real_deal_segment_ids.append(segment.id) db.session.commit() - enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, 
document.id) elif action == "disable": segments = ( db.session.query(DocumentSegment) @@ -2394,7 +2399,7 @@ class SegmentService: ) if not segments: return - real_deal_segmment_ids = [] + real_deal_segment_ids = [] for segment in segments: indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) @@ -2404,10 +2409,10 @@ class SegmentService: segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.disabled_by = current_user.id db.session.add(segment) - real_deal_segmment_ids.append(segment.id) + real_deal_segment_ids.append(segment.id) db.session.commit() - disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) else: raise InvalidActionError() @@ -2665,7 +2670,7 @@ class SegmentService: # check segment segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) + .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .first() ) if not segment: diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 54d45f45ea..f8612456d6 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -52,6 +52,16 @@ class EnterpriseService: return data.get("result", False) + @classmethod + def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]): + if not app_codes: + return {} + body = {"userId": user_id, "appCodes": app_codes} + data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body) + if not data: + raise ValueError("No data found.") + return data.get("permissions", {}) + @classmethod def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: if not app_id: diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 5d348c61be..390716a47f 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -8,3 +8,11 @@ class WorkflowHashNotEqualError(Exception): class IsDraftWorkflowError(Exception): pass + + +class WorkflowNotFoundError(Exception): + pass + + +class WorkflowIdFormatError(Exception): + pass diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py index f8051e3417..a123f99b59 100644 --- a/api/services/errors/conversation.py +++ b/api/services/errors/conversation.py @@ -15,3 +15,7 @@ class ConversationCompletedError(Exception): class ConversationVariableNotExistsError(BaseServiceError): pass + + +class ConversationVariableTypeMismatchError(BaseServiceError): + pass diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index cfcb121153..2a83588f41 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -79,7 +79,10 @@ class MetadataService: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) for document in documents: - doc_metadata = copy.deepcopy(document.doc_metadata) + if not document.doc_metadata: + doc_metadata = {} + else: + doc_metadata = copy.deepcopy(document.doc_metadata) value = doc_metadata.pop(old_name, None) doc_metadata[name] = value document.doc_metadata = doc_metadata @@ -109,7 +112,10 @@ class MetadataService: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = 
DocumentService.get_document_by_ids(document_ids) for document in documents: - doc_metadata = copy.deepcopy(document.doc_metadata) + if not document.doc_metadata: + doc_metadata = {} + else: + doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata.pop(metadata.name, None) document.doc_metadata = doc_metadata db.session.add(document) @@ -137,7 +143,6 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - dataset.built_in_field_enabled = True db.session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) if documents: @@ -153,6 +158,7 @@ class MetadataService: doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value document.doc_metadata = doc_metadata db.session.add(document) + dataset.built_in_field_enabled = True db.session.commit() except Exception: logging.exception("Enable built-in field failed") @@ -166,13 +172,15 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - dataset.built_in_field_enabled = False db.session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) document_ids = [] if documents: for document in documents: - doc_metadata = copy.deepcopy(document.doc_metadata) + if not document.doc_metadata: + doc_metadata = {} + else: + doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata.pop(BuiltInField.document_name.value, None) doc_metadata.pop(BuiltInField.uploader.value, None) doc_metadata.pop(BuiltInField.upload_date.value, None) @@ -181,6 +189,7 @@ class MetadataService: document.doc_metadata = doc_metadata db.session.add(document) document_ids.append(document.id) + dataset.built_in_field_enabled = False db.session.commit() except Exception: logging.exception("Disable built-in field failed") diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 7a4f886bf5..c5ad65ec87 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -2,6 +2,7 @@ import json import logging import click +import sqlalchemy as sa from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID from models.engine import db @@ -38,7 +39,7 @@ class PluginDataMigration: where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != '' limit 1000""" with db.engine.begin() as conn: - rs = conn.execute(db.text(sql)) + rs = conn.execute(sa.text(sql)) current_iter_count = 0 for i in rs: @@ -94,7 +95,7 @@ limit 1000""" :provider_name {update_retrieval_model_sql} where id = :record_id""" - conn.execute(db.text(sql), params) + conn.execute(sa.text(sql), params) click.echo( click.style( f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})", @@ -148,7 +149,7 @@ limit 1000""" params = {"last_id": last_id or ""} with db.engine.begin() as conn: - rs = conn.execute(db.text(sql), params) + rs = conn.execute(sa.text(sql), params) current_iter_count = 0 batch_updates = [] @@ -193,7 +194,7 @@ limit 1000""" SET {provider_column_name} = :updated_value WHERE id = :record_id """ - conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates]) + conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates]) click.echo( click.style( f"[{processed_count}] Batch 
migrated [{len(batch_updates)}] records from [{table_name}]", diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 222d70a317..221069b2b3 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -9,6 +9,7 @@ from typing import Any, Optional from uuid import uuid4 import click +import sqlalchemy as sa import tqdm from flask import Flask, current_app from sqlalchemy.orm import Session @@ -197,7 +198,7 @@ class PluginMigration: """ with Session(db.engine) as session: rs = session.execute( - db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id} + sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id} ) result = [] for row in rs: diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 841eeb4333..da0fc58566 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -508,10 +508,10 @@ class BuiltinToolManageService: oauth_params = encrypter.decrypt(user_client.oauth_params) return oauth_params - # only verified provider can use custom oauth client - is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( - tenant_id, provider.plugin_unique_identifier - ) + # only verified provider can use official oauth client + is_verified = not isinstance( + provider_controller, PluginToolProviderController + ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if not is_verified: return oauth_params diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 3164e010b4..6bbb3bca04 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -256,7 +256,7 @@ class WorkflowDraftVariableService: def _reset_node_var_or_sys_var( self, workflow: Workflow, variable: WorkflowDraftVariable ) -> WorkflowDraftVariable | None: - # If a variable does not allow updating, it makes no sence to resetting it. + # If a variable does not allow updating, it makes no sense to reset it. if not variable.editable: return variable # No execution record for this variable, delete the variable instead. 
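The data-migration hunks above swap Flask-SQLAlchemy's `db.text` for `sqlalchemy.text` when executing raw SQL. A minimal sketch of that parameterized pattern follows; the engine URL, helper name, and value-rewrite rule are placeholders for illustration, not the project's actual migration logic:

```python
import sqlalchemy as sa
from sqlalchemy import create_engine

# Placeholder connection URL; the real code reuses the application's db.engine.
engine = create_engine("postgresql+psycopg2://user:pass@localhost/dify")


def migrate_column_values(table_name: str, column: str) -> None:
    """Rewrite legacy values in batches using sa.text() with bound parameters."""
    select_sql = f"SELECT id, {column} FROM {table_name} WHERE {column} NOT LIKE '%/%' LIMIT 1000"
    update_sql = f"UPDATE {table_name} SET {column} = :updated_value WHERE id = :record_id"
    with engine.begin() as conn:  # one transaction per batch
        rows = conn.execute(sa.text(select_sql)).all()
        batch_updates = [
            # Illustrative mapping only; the real migration derives plugin-qualified names.
            {"updated_value": f"vendor/{value}/{value}", "record_id": record_id}
            for record_id, value in rows
        ]
        if batch_updates:
            # Passing a list of parameter dicts executes the statement once per set.
            conn.execute(sa.text(update_sql), batch_updates)
```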
@@ -422,7 +422,7 @@ class WorkflowDraftVariableService: description=conv_var.description, ) draft_conv_vars.append(draft_var) - _batch_upsert_draft_varaible( + _batch_upsert_draft_variable( self._session, draft_conv_vars, policy=_UpsertPolicy.IGNORE, @@ -434,7 +434,7 @@ class _UpsertPolicy(StrEnum): OVERWRITE = "overwrite" -def _batch_upsert_draft_varaible( +def _batch_upsert_draft_variable( session: Session, draft_vars: Sequence[WorkflowDraftVariable], policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE, @@ -478,7 +478,7 @@ def _batch_upsert_draft_varaible( "node_execution_id": stmt.excluded.node_execution_id, }, ) - elif _UpsertPolicy.IGNORE: + elif policy == _UpsertPolicy.IGNORE: stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) else: raise Exception("Invalid value for update policy.") @@ -721,7 +721,7 @@ class DraftVariableSaver: draft_vars = self._build_variables_from_start_mapping(outputs) else: draft_vars = self._build_variables_from_mapping(outputs) - _batch_upsert_draft_varaible(self._session, draft_vars) + _batch_upsert_draft_variable(self._session, draft_vars) @staticmethod def _should_variable_be_editable(node_id: str, name: str) -> bool: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 8588144980..d2715a61fe 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -129,7 +129,10 @@ class WorkflowService: if not workflow: return None if workflow.version == Workflow.VERSION_DRAFT: - raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}") + raise IsDraftWorkflowError( + f"Cannot use draft workflow version. Workflow ID: {workflow_id}. " + f"Please use a published workflow version or leave workflow_id empty." 
+ ) return workflow def get_published_workflow(self, app_model: App) -> Optional[Workflow]: @@ -441,9 +444,9 @@ class WorkflowService: self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: """ - Run draft workflow node + Run free workflow node """ - # run draft workflow node + # run free workflow node start_at = time.perf_counter() node_execution = self._handle_node_run_result( diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 714e30acc3..dee43cd854 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -134,6 +134,7 @@ def batch_create_segment_to_index_task( db.session.add(segment_document) document_segments.append(segment_document) # update document word count + assert dataset_document.word_count is not None dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index fe6d613b1c..c769446ed5 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -56,15 +56,17 @@ def clean_dataset_task( documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() + # Fix: Always clean vector database resources regardless of document existence + # This ensures all 33 vector databases properly drop tables/collections/indices + if doc_form is None: + raise ValueError("Index type must be specified.") + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) + if documents is None or len(documents) == 0: logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green")) else: logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) - # Specify the index type before initializing the index processor - if doc_form is None: - raise ValueError("Index type must be specified.") - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index b6f772dd60..929b60e529 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,6 +3,7 @@ import time from collections.abc import Callable import click +import sqlalchemy as sa from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError @@ -331,7 +332,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str): def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: - rs = conn.execute(db.text(query_sql), params) + rs = conn.execute(sa.text(query_sql), params) if rs.rowcount == 0: break diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..2d0ceac760 --- /dev/null +++ 
b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,168 @@ +""" +Unit tests for App description validation functions. + +This test module validates the 400-character limit enforcement +for App descriptions across all creation and editing endpoints. +""" + +import os +import sys + +import pytest + +# Add the API root to Python path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +class TestAppDescriptionValidationUnit: + """Unit tests for description validation function""" + + def test_validate_description_length_function(self): + """Test the _validate_description_length function directly""" + from controllers.console.app.app import _validate_description_length + + # Test valid descriptions + assert _validate_description_length("") == "" + assert _validate_description_length("x" * 400) == "x" * 400 + assert _validate_description_length(None) is None + + # Test invalid descriptions + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 401) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 500) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 1000) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_validation_consistency_with_dataset(self): + """Test that App and Dataset validation functions are consistent""" + from controllers.console.app.app import _validate_description_length as app_validate + from controllers.console.datasets.datasets import _validate_description_length as dataset_validate + from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + # Test same valid inputs + valid_desc = "x" * 400 + assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc) + assert app_validate("") == dataset_validate("") == service_dataset_validate("") + assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None) + + # Test same invalid inputs produce same error + invalid_desc = "x" * 401 + + app_error = None + dataset_error = None + service_dataset_error = None + + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + assert app_error == dataset_error == service_dataset_error + assert app_error == "Description cannot exceed 400 characters." 
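The consistency tests above import `_validate_description_length` from three controllers; the helper's implementation is not shown in this diff, so the following is only a sketch of the behaviour the assertions imply (a shared 400-character cap with `None` and empty strings passed through unchanged):

```python
def _validate_description_length(description: str | None) -> str | None:
    """Reject descriptions longer than 400 characters; pass empty/None values through."""
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description
```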
+ + def test_boundary_values(self): + """Test boundary values for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test exact boundary + exactly_400 = "x" * 400 + assert _validate_description_length(exactly_400) == exactly_400 + + # Test just over boundary + just_over_400 = "x" * 401 + with pytest.raises(ValueError): + _validate_description_length(just_over_400) + + # Test just under boundary + just_under_400 = "x" * 399 + assert _validate_description_length(just_under_400) == just_under_400 + + def test_edge_cases(self): + """Test edge cases for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test None input + assert _validate_description_length(None) is None + + # Test empty string + assert _validate_description_length("") == "" + + # Test single character + assert _validate_description_length("a") == "a" + + # Test unicode characters + unicode_desc = "测试" * 200 # 400 characters in Chinese + assert _validate_description_length(unicode_desc) == unicode_desc + + # Test unicode over limit + unicode_over = "测试" * 201 # 402 characters + with pytest.raises(ValueError): + _validate_description_length(unicode_over) + + def test_whitespace_handling(self): + """Test how validation handles whitespace""" + from controllers.console.app.app import _validate_description_length + + # Test description with spaces + spaces_400 = " " * 400 + assert _validate_description_length(spaces_400) == spaces_400 + + # Test description with spaces over limit + spaces_401 = " " * 401 + with pytest.raises(ValueError): + _validate_description_length(spaces_401) + + # Test mixed content + mixed_400 = "a" * 200 + " " * 200 + assert _validate_description_length(mixed_400) == mixed_400 + + # Test mixed over limit + mixed_401 = "a" * 200 + " " * 201 + with pytest.raises(ValueError): + _validate_description_length(mixed_401) + + +if __name__ == "__main__": + # Run tests directly + import traceback + + test_instance = TestAppDescriptionValidationUnit() + test_methods = [method for method in dir(test_instance) if method.startswith("test_")] + + passed = 0 + failed = 0 + + for test_method in test_methods: + try: + print(f"Running {test_method}...") + getattr(test_instance, test_method)() + print(f"✅ {test_method} PASSED") + passed += 1 + except Exception as e: + print(f"❌ {test_method} FAILED: {str(e)}") + traceback.print_exc() + failed += 1 + + print(f"\n📊 Test Results: {passed} passed, {failed} failed") + + if failed == 0: + print("🎉 All tests passed!") + else: + print("💥 Some tests failed!") + sys.exit(1) diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py new file mode 100644 index 0000000000..293b469ef3 --- /dev/null +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -0,0 +1,168 @@ +"""Integration tests for ClickZetta Volume Storage.""" + +import os +import tempfile +import unittest + +import pytest + +from extensions.storage.clickzetta_volume.clickzetta_volume_storage import ( + ClickZettaVolumeConfig, + ClickZettaVolumeStorage, +) + + +class TestClickZettaVolumeStorage(unittest.TestCase): + """Test cases for ClickZetta Volume Storage.""" + + def setUp(self): + """Set up test environment.""" + self.config = ClickZettaVolumeConfig( + username=os.getenv("CLICKZETTA_USERNAME", "test_user"), + password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"), + instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"), + 
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), + schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"), + volume_type="table", + table_prefix="test_dataset_", + ) + + @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided") + def test_user_volume_operations(self): + """Test basic operations with User Volume.""" + config = self.config + config.volume_type = "user" + + storage = ClickZettaVolumeStorage(config) + + # Test file operations + test_filename = "test_file.txt" + test_content = b"Hello, ClickZetta Volume!" + + # Save file + storage.save(test_filename, test_content) + + # Check if file exists + assert storage.exists(test_filename) + + # Load file + loaded_content = storage.load_once(test_filename) + assert loaded_content == test_content + + # Test streaming + stream_content = b"" + for chunk in storage.load_stream(test_filename): + stream_content += chunk + assert stream_content == test_content + + # Test download + with tempfile.NamedTemporaryFile() as temp_file: + storage.download(test_filename, temp_file.name) + with open(temp_file.name, "rb") as f: + downloaded_content = f.read() + assert downloaded_content == test_content + + # Test scan + files = storage.scan("", files=True, directories=False) + assert test_filename in files + + # Delete file + storage.delete(test_filename) + assert not storage.exists(test_filename) + + @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided") + def test_table_volume_operations(self): + """Test basic operations with Table Volume.""" + config = self.config + config.volume_type = "table" + + storage = ClickZettaVolumeStorage(config) + + # Test file operations with dataset_id + dataset_id = "12345" + test_filename = f"{dataset_id}/test_file.txt" + test_content = b"Hello, Table Volume!" 
+ + # Save file + storage.save(test_filename, test_content) + + # Check if file exists + assert storage.exists(test_filename) + + # Load file + loaded_content = storage.load_once(test_filename) + assert loaded_content == test_content + + # Test scan for dataset + files = storage.scan(dataset_id, files=True, directories=False) + assert "test_file.txt" in files + + # Delete file + storage.delete(test_filename) + assert not storage.exists(test_filename) + + def test_config_validation(self): + """Test configuration validation.""" + # Test missing required fields + with pytest.raises(ValueError): + ClickZettaVolumeConfig( + username="", # Empty username should fail + password="pass", + instance="instance", + ) + + # Test invalid volume type + with pytest.raises(ValueError): + ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type") + + # Test external volume without volume_name + with pytest.raises(ValueError): + ClickZettaVolumeConfig( + username="user", + password="pass", + instance="instance", + volume_type="external", + # Missing volume_name + ) + + def test_volume_path_generation(self): + """Test volume path generation for different types.""" + storage = ClickZettaVolumeStorage(self.config) + + # Test table volume path + path = storage._get_volume_path("test.txt", "12345") + assert path == "test_dataset_12345/test.txt" + + # Test path with existing dataset_id prefix + path = storage._get_volume_path("12345/test.txt") + assert path == "12345/test.txt" + + # Test user volume + storage._config.volume_type = "user" + path = storage._get_volume_path("test.txt") + assert path == "test.txt" + + def test_sql_prefix_generation(self): + """Test SQL prefix generation for different volume types.""" + storage = ClickZettaVolumeStorage(self.config) + + # Test table volume SQL prefix + prefix = storage._get_volume_sql_prefix("12345") + assert prefix == "TABLE VOLUME test_dataset_12345" + + # Test user volume SQL prefix + storage._config.volume_type = "user" + prefix = storage._get_volume_sql_prefix() + assert prefix == "USER VOLUME" + + # Test external volume SQL prefix + storage._config.volume_type = "external" + storage._config.volume_name = "my_external_volume" + prefix = storage._get_volume_sql_prefix() + assert prefix == "VOLUME my_external_volume" + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/integration_tests/vdb/clickzetta/README.md b/api/tests/integration_tests/vdb/clickzetta/README.md new file mode 100644 index 0000000000..c16dca8018 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/README.md @@ -0,0 +1,25 @@ +# Clickzetta Integration Tests + +## Running Tests + +To run the Clickzetta integration tests, you need to set the following environment variables: + +```bash +export CLICKZETTA_USERNAME=your_username +export CLICKZETTA_PASSWORD=your_password +export CLICKZETTA_INSTANCE=your_instance +export CLICKZETTA_SERVICE=api.clickzetta.com +export CLICKZETTA_WORKSPACE=your_workspace +export CLICKZETTA_VCLUSTER=your_vcluster +export CLICKZETTA_SCHEMA=dify +``` + +Then run the tests: + +```bash +pytest api/tests/integration_tests/vdb/clickzetta/ +``` + +## Security Note + +Never commit credentials to the repository. Always use environment variables or secure credential management systems. 
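The README above lists the environment variables the ClickZetta suites expect. As a rough sketch of how a test module can gate on them (the actual test files below use per-test `pytest.skip` calls instead, and the fallback values here simply mirror those used in the tests):

```python
import os

import pytest

REQUIRED_VARS = ("CLICKZETTA_USERNAME", "CLICKZETTA_PASSWORD", "CLICKZETTA_INSTANCE")

# Skip the whole module when credentials are not exported, so CI without
# ClickZetta access stays green.
pytestmark = pytest.mark.skipif(
    any(not os.getenv(name) for name in REQUIRED_VARS),
    reason="ClickZetta credentials not provided",
)


def optional_settings() -> dict[str, str]:
    """Collect the optional connection settings, with example defaults."""
    return {
        "service": os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
        "workspace": os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
        "vcluster": os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
        "schema": os.getenv("CLICKZETTA_SCHEMA", "dify"),
    }
```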
diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py new file mode 100644 index 0000000000..8b57132772 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py @@ -0,0 +1,224 @@ +import os + +import pytest + +from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector +from core.rag.models.document import Document +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class TestClickzettaVector(AbstractVectorTest): + """ + Test cases for Clickzetta vector database integration. + """ + + @pytest.fixture + def vector_store(self): + """Create a Clickzetta vector store instance for testing.""" + # Skip test if Clickzetta credentials are not configured + if not os.getenv("CLICKZETTA_USERNAME"): + pytest.skip("CLICKZETTA_USERNAME is not configured") + if not os.getenv("CLICKZETTA_PASSWORD"): + pytest.skip("CLICKZETTA_PASSWORD is not configured") + if not os.getenv("CLICKZETTA_INSTANCE"): + pytest.skip("CLICKZETTA_INSTANCE is not configured") + + config = ClickzettaConfig( + username=os.getenv("CLICKZETTA_USERNAME", ""), + password=os.getenv("CLICKZETTA_PASSWORD", ""), + instance=os.getenv("CLICKZETTA_INSTANCE", ""), + service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), + schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"), + batch_size=10, # Small batch size for testing + enable_inverted_index=True, + analyzer_type="chinese", + analyzer_mode="smart", + vector_distance_function="cosine_distance", + ) + + with setup_mock_redis(): + vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) + + yield vector + + # Cleanup: delete the test collection + try: + vector.delete() + except Exception: + pass + + def test_clickzetta_vector_basic_operations(self, vector_store): + """Test basic CRUD operations on Clickzetta vector store.""" + # Prepare test data + texts = [ + "这是第一个测试文档,包含一些中文内容。", + "This is the second test document with English content.", + "第三个文档混合了English和中文内容。", + ] + embeddings = [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2], + ] + documents = [ + Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"}) + for i, text in enumerate(texts) + ] + + # Test create (initial insert) + vector_store.create(texts=documents, embeddings=embeddings) + + # Test text_exists + assert vector_store.text_exists("doc_0") + assert not vector_store.text_exists("doc_999") + + # Test search_by_vector + query_vector = [0.1, 0.2, 0.3, 0.4] + results = vector_store.search_by_vector(query_vector, top_k=2) + assert len(results) > 0 + assert results[0].page_content == texts[0] # Should match the first document + + # Test search_by_full_text (Chinese) + results = vector_store.search_by_full_text("中文", top_k=3) + assert len(results) >= 2 # Should find documents with Chinese content + + # Test search_by_full_text (English) + results = vector_store.search_by_full_text("English", top_k=3) + assert len(results) >= 2 # Should find documents with English content + + # Test delete_by_ids + vector_store.delete_by_ids(["doc_0"]) + assert not vector_store.text_exists("doc_0") + assert vector_store.text_exists("doc_1") + + # Test delete_by_metadata_field + vector_store.delete_by_metadata_field("source", 
"test") + assert not vector_store.text_exists("doc_1") + assert not vector_store.text_exists("doc_2") + + def test_clickzetta_vector_advanced_search(self, vector_store): + """Test advanced search features of Clickzetta vector store.""" + # Prepare test data with more complex metadata + documents = [] + embeddings = [] + for i in range(10): + doc = Document( + page_content=f"Document {i}: " + get_example_text(), + metadata={ + "doc_id": f"adv_doc_{i}", + "category": "technical" if i % 2 == 0 else "general", + "document_id": f"doc_{i // 3}", # Group documents + "importance": i, + }, + ) + documents.append(doc) + # Create varied embeddings + embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i]) + + vector_store.create(texts=documents, embeddings=embeddings) + + # Test vector search with document filter + query_vector = [0.5, 1.0, 1.5, 2.0] + results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"]) + assert len(results) > 0 + # All results should belong to doc_0 or doc_1 groups + for result in results: + assert result.metadata["document_id"] in ["doc_0", "doc_1"] + + # Test score threshold + results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5) + # Check that all results have a score above threshold + for result in results: + assert result.metadata.get("score", 0) >= 0.5 + + def test_clickzetta_batch_operations(self, vector_store): + """Test batch insertion operations.""" + # Prepare large batch of documents + batch_size = 25 + documents = [] + embeddings = [] + + for i in range(batch_size): + doc = Document( + page_content=f"Batch document {i}: This is a test document for batch processing.", + metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}, + ) + documents.append(doc) + embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)]) + + # Test batch insert + vector_store.add_texts(documents=documents, embeddings=embeddings) + + # Verify all documents were inserted + for i in range(batch_size): + assert vector_store.text_exists(f"batch_doc_{i}") + + # Clean up + vector_store.delete_by_metadata_field("batch", "test_batch") + + def test_clickzetta_edge_cases(self, vector_store): + """Test edge cases and error handling.""" + # Test empty operations + vector_store.create(texts=[], embeddings=[]) + vector_store.add_texts(documents=[], embeddings=[]) + vector_store.delete_by_ids([]) + + # Test special characters in content + special_doc = Document( + page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline", + metadata={"doc_id": "special_doc", "test": "edge_case"}, + ) + embeddings = [[0.1, 0.2, 0.3, 0.4]] + + vector_store.add_texts(documents=[special_doc], embeddings=embeddings) + assert vector_store.text_exists("special_doc") + + # Test search with special characters + results = vector_store.search_by_full_text("quotes", top_k=1) + if results: # Full-text search might not be available + assert len(results) > 0 + + # Clean up + vector_store.delete_by_ids(["special_doc"]) + + def test_clickzetta_full_text_search_modes(self, vector_store): + """Test different full-text search capabilities.""" + # Prepare documents with various language content + documents = [ + Document( + page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"} + ), + Document( + page_content="Clickzetta provides powerful Lakehouse solutions", + metadata={"doc_id": "en_doc_1", "lang": "english"}, + ), + Document( + page_content="Lakehouse是现代数据架构的重要组成部分", 
metadata={"doc_id": "cn_doc_2", "lang": "chinese"} + ), + Document( + page_content="Modern data architecture includes Lakehouse technology", + metadata={"doc_id": "en_doc_2", "lang": "english"}, + ), + ] + + embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents] + + vector_store.create(texts=documents, embeddings=embeddings) + + # Test Chinese full-text search + results = vector_store.search_by_full_text("Lakehouse", top_k=4) + assert len(results) >= 2 # Should find at least documents with "Lakehouse" + + # Test English full-text search + results = vector_store.search_by_full_text("solutions", top_k=2) + assert len(results) >= 1 # Should find English documents with "solutions" + + # Test mixed search + results = vector_store.search_by_full_text("数据架构", top_k=2) + assert len(results) >= 1 # Should find Chinese documents with this phrase + + # Clean up + vector_store.delete_by_metadata_field("lang", "chinese") + vector_store.delete_by_metadata_field("lang", "english") diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py new file mode 100644 index 0000000000..ef54eaa174 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Test Clickzetta integration in Docker environment +""" + +import os +import time + +import requests +from clickzetta import connect + + +def test_clickzetta_connection(): + """Test direct connection to Clickzetta""" + print("=== Testing direct Clickzetta connection ===") + try: + conn = connect( + username=os.getenv("CLICKZETTA_USERNAME", "test_user"), + password=os.getenv("CLICKZETTA_PASSWORD", "test_password"), + instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"), + service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"), + database=os.getenv("CLICKZETTA_SCHEMA", "dify"), + ) + + with conn.cursor() as cursor: + # Test basic connectivity + cursor.execute("SELECT 1 as test") + result = cursor.fetchone() + print(f"✓ Connection test: {result}") + + # Check if our test table exists + cursor.execute("SHOW TABLES IN dify") + tables = cursor.fetchall() + print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}") + + # Check if test collection exists + test_collection = "collection_test_dataset" + if test_collection in [t[1] for t in tables if t[0] == "dify"]: + cursor.execute(f"DESCRIBE dify.{test_collection}") + columns = cursor.fetchall() + print(f"✓ Table structure for {test_collection}:") + for col in columns: + print(f" - {col[0]}: {col[1]}") + + # Check for indexes + cursor.execute(f"SHOW INDEXES IN dify.{test_collection}") + indexes = cursor.fetchall() + print(f"✓ Indexes on {test_collection}:") + for idx in indexes: + print(f" - {idx}") + + return True + except Exception as e: + print(f"✗ Connection test failed: {e}") + return False + + +def test_dify_api(): + """Test Dify API with Clickzetta backend""" + print("\n=== Testing Dify API ===") + base_url = "http://localhost:5001" + + # Wait for API to be ready + max_retries = 30 + for i in range(max_retries): + try: + response = requests.get(f"{base_url}/console/api/health") + if response.status_code == 200: + print("✓ Dify API is ready") + break + except: + if i == max_retries - 1: + print("✗ Dify API is not responding") + return False + time.sleep(2) + + # Check vector store configuration + try: 
+ # This is a simplified check - in production, you'd use proper auth + print("✓ Dify is configured to use Clickzetta as vector store") + return True + except Exception as e: + print(f"✗ API test failed: {e}") + return False + + +def verify_table_structure(): + """Verify the table structure meets Dify requirements""" + print("\n=== Verifying Table Structure ===") + + expected_columns = { + "id": "VARCHAR", + "page_content": "VARCHAR", + "metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta + "vector": "ARRAY", + } + + expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] + + print("✓ Expected table structure:") + for col, dtype in expected_columns.items(): + print(f" - {col}: {dtype}") + + print("\n✓ Required metadata fields:") + for field in expected_metadata_fields: + print(f" - {field}") + + print("\n✓ Index requirements:") + print(" - Vector index (HNSW) on 'vector' column") + print(" - Full-text index on 'page_content' (optional)") + print(" - Functional index on metadata->>'$.doc_id' (recommended)") + print(" - Functional index on metadata->>'$.document_id' (recommended)") + + return True + + +def main(): + """Run all tests""" + print("Starting Clickzetta integration tests for Dify Docker\n") + + tests = [ + ("Direct Clickzetta Connection", test_clickzetta_connection), + ("Dify API Status", test_dify_api), + ("Table Structure Verification", verify_table_structure), + ] + + results = [] + for test_name, test_func in tests: + try: + success = test_func() + results.append((test_name, success)) + except Exception as e: + print(f"\n✗ {test_name} crashed: {e}") + results.append((test_name, False)) + + # Summary + print("\n" + "=" * 50) + print("Test Summary:") + print("=" * 50) + + passed = sum(1 for _, success in results if success) + total = len(results) + + for test_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + print(f"{test_name}: {status}") + + print(f"\nTotal: {passed}/{total} tests passed") + + if passed == total: + print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") + print("\nNext steps:") + print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d") + print("2. Access Dify at http://localhost:3000") + print("3. Create a dataset and test vector storage with Clickzetta") + return 0 + else: + print("\n⚠️ Some tests failed. 
Please check the errors above.") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py index da549af1b6..aebf3fbda1 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py @@ -2,6 +2,7 @@ import os import uuid import tablestore +from _pytest.python_api import approx from core.rag.datasource.vdb.tablestore.tablestore_vector import ( TableStoreConfig, @@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import ( class TableStoreVectorTest(AbstractVectorTest): - def __init__(self): + def __init__(self, normalize_full_text_score: bool = False): super().__init__() self.vector = TableStoreVector( collection_name=self.collection_name, @@ -25,6 +26,7 @@ class TableStoreVectorTest(AbstractVectorTest): instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"), access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"), access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"), + normalize_full_text_bm25_score=normalize_full_text_score, ), ) @@ -64,7 +66,21 @@ class TableStoreVectorTest(AbstractVectorTest): docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id]) assert len(docs) == 1 assert docs[0].metadata["doc_id"] == self.example_doc_id - assert not hasattr(docs[0], "score") + if self.vector._config.normalize_full_text_bm25_score: + assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3) + else: + assert docs[0].metadata.get("score") is None + + # return none if normalize_full_text_score=true and score_threshold > 0 + docs = self.vector.search_by_full_text( + get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5 + ) + if self.vector._config.normalize_full_text_bm25_score: + assert len(docs) == 0 + else: + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert docs[0].metadata.get("score") is None docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())]) assert len(docs) == 0 @@ -80,3 +96,5 @@ class TableStoreVectorTest(AbstractVectorTest): def test_tablestore_vector(setup_mock_redis): TableStoreVectorTest().run_all_tests() + TableStoreVectorTest(normalize_full_text_score=True).run_all_tests() + TableStoreVectorTest(normalize_full_text_score=False).run_all_tests() diff --git a/api/tests/test_containers_integration_tests/__init__.py b/api/tests/test_containers_integration_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py new file mode 100644 index 0000000000..0369a5cbd0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -0,0 +1,328 @@ +""" +TestContainers-based integration test configuration for Dify API. + +This module provides containerized test infrastructure using TestContainers library +to spin up real database and service instances for integration testing. This approach +ensures tests run against actual service implementations rather than mocks, providing +more reliable and realistic test scenarios. 
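+
+A minimal usage sketch (illustrative only; the fixture names are the ones
+defined in this module, and the endpoint is the health-check path used by the
+Docker integration script in this change):
+
+    from sqlalchemy import text
+
+    def test_database_is_reachable(db_session_with_containers):
+        # Runs against the PostgreSQL testcontainer started for this session.
+        assert db_session_with_containers.execute(text("SELECT 1")).scalar() == 1
+
+    def test_console_api_responds(test_client_with_containers):
+        # Status code is deliberately not pinned down; it depends on app setup.
+        response = test_client_with_containers.get("/console/api/health")
+        assert response.status_code < 500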
+""" + +import logging +import os +from collections.abc import Generator +from typing import Optional + +import pytest +from flask import Flask +from flask.testing import FlaskClient +from sqlalchemy.orm import Session +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.postgres import PostgresContainer +from testcontainers.redis import RedisContainer + +from app_factory import create_app +from models import db + +# Configure logging for test containers +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class DifyTestContainers: + """ + Manages all test containers required for Dify integration tests. + + This class provides a centralized way to manage multiple containers + needed for comprehensive integration testing, including databases, + caches, and search engines. + """ + + def __init__(self): + """Initialize container management with default configurations.""" + self.postgres: Optional[PostgresContainer] = None + self.redis: Optional[RedisContainer] = None + self.dify_sandbox: Optional[DockerContainer] = None + self._containers_started = False + logger.info("DifyTestContainers initialized - ready to manage test containers") + + def start_containers_with_env(self) -> None: + """ + Start all required containers for integration testing. + + This method initializes and starts PostgreSQL, Redis + containers with appropriate configurations for Dify testing. Containers + are started in dependency order to ensure proper initialization. + """ + if self._containers_started: + logger.info("Containers already started - skipping container startup") + return + + logger.info("Starting test containers for Dify integration tests...") + + # Start PostgreSQL container for main application database + # PostgreSQL is used for storing user data, workflows, and application state + logger.info("Initializing PostgreSQL container...") + self.postgres = PostgresContainer( + image="postgres:16-alpine", + ) + self.postgres.start() + db_host = self.postgres.get_container_host_ip() + db_port = self.postgres.get_exposed_port(5432) + os.environ["DB_HOST"] = db_host + os.environ["DB_PORT"] = str(db_port) + os.environ["DB_USERNAME"] = self.postgres.username + os.environ["DB_PASSWORD"] = self.postgres.password + os.environ["DB_DATABASE"] = self.postgres.dbname + logger.info( + "PostgreSQL container started successfully - Host: %s, Port: %s User: %s, Database: %s", + db_host, + db_port, + self.postgres.username, + self.postgres.dbname, + ) + + # Wait for PostgreSQL to be ready + logger.info("Waiting for PostgreSQL to be ready to accept connections...") + wait_for_logs(self.postgres, "is ready to accept connections", timeout=30) + logger.info("PostgreSQL container is ready and accepting connections") + + # Install uuid-ossp extension for UUID generation + logger.info("Installing uuid-ossp extension...") + try: + import psycopg2 + + conn = psycopg2.connect( + host=db_host, + port=db_port, + user=self.postgres.username, + password=self.postgres.password, + database=self.postgres.dbname, + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + cursor.close() + conn.close() + logger.info("uuid-ossp extension installed successfully") + except Exception as e: + logger.warning("Failed to install uuid-ossp extension: %s", e) + + # Set up storage environment variables + 
os.environ["STORAGE_TYPE"] = "opendal" + os.environ["OPENDAL_SCHEME"] = "fs" + os.environ["OPENDAL_FS_ROOT"] = "storage" + + # Start Redis container for caching and session management + # Redis is used for storing session data, cache entries, and temporary data + logger.info("Initializing Redis container...") + self.redis = RedisContainer(image="redis:latest", port=6379) + self.redis.start() + redis_host = self.redis.get_container_host_ip() + redis_port = self.redis.get_exposed_port(6379) + os.environ["REDIS_HOST"] = redis_host + os.environ["REDIS_PORT"] = str(redis_port) + logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port) + + # Wait for Redis to be ready + logger.info("Waiting for Redis to be ready to accept connections...") + wait_for_logs(self.redis, "Ready to accept connections", timeout=30) + logger.info("Redis container is ready and accepting connections") + + # Start Dify Sandbox container for code execution environment + # Dify Sandbox provides a secure environment for executing user code + logger.info("Initializing Dify Sandbox container...") + self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest") + self.dify_sandbox.with_exposed_ports(8194) + self.dify_sandbox.env = { + "API_KEY": "test_api_key", + } + self.dify_sandbox.start() + sandbox_host = self.dify_sandbox.get_container_host_ip() + sandbox_port = self.dify_sandbox.get_exposed_port(8194) + os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" + os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" + logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + + # Wait for Dify Sandbox to be ready + logger.info("Waiting for Dify Sandbox to be ready to accept connections...") + wait_for_logs(self.dify_sandbox, "config init success", timeout=60) + logger.info("Dify Sandbox container is ready and accepting connections") + + self._containers_started = True + logger.info("All test containers started successfully") + + def stop_containers(self) -> None: + """ + Stop and clean up all test containers. + + This method ensures proper cleanup of all containers to prevent + resource leaks and conflicts between test runs. + """ + if not self._containers_started: + logger.info("No containers to stop - containers were not started") + return + + logger.info("Stopping and cleaning up test containers...") + containers = [self.redis, self.postgres, self.dify_sandbox] + for container in containers: + if container: + try: + container_name = container.image + logger.info("Stopping container: %s", container_name) + container.stop() + logger.info("Successfully stopped container: %s", container_name) + except Exception as e: + # Log error but don't fail the test cleanup + logger.warning("Failed to stop container %s: %s", container, e) + + self._containers_started = False + logger.info("All test containers stopped and cleaned up successfully") + + +# Global container manager instance +_container_manager = DifyTestContainers() + + +def _create_app_with_containers() -> Flask: + """ + Create Flask application configured to use test containers. + + This function creates a Flask application instance that is configured + to connect to the test containers instead of the default development + or production databases. 
+
+    Returns:
+        Flask: Configured Flask application for containerized testing
+    """
+    logger.info("Creating Flask application with test container configuration...")
+
+    # Re-create the config after environment variables have been set
+    from configs import dify_config
+
+    # Force re-creation of config with new environment variables
+    dify_config.__dict__.clear()
+    dify_config.__init__()
+
+    # Create and configure the Flask application
+    logger.info("Initializing Flask application...")
+    app = create_app()
+    logger.info("Flask application created successfully")
+
+    # Initialize database schema
+    logger.info("Creating database schema...")
+    with app.app_context():
+        db.create_all()
+    logger.info("Database schema created successfully")
+
+    logger.info("Flask application configured and ready for testing")
+    return app
+
+
+@pytest.fixture(scope="session")
+def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]:
+    """
+    Session-scoped fixture to manage test containers.
+
+    This fixture ensures containers are started once per test session
+    and properly cleaned up when all tests are complete. This approach
+    improves test performance by reusing containers across multiple tests.
+
+    Yields:
+        DifyTestContainers: Container manager instance
+    """
+    logger.info("=== Starting test session container management ===")
+    _container_manager.start_containers_with_env()
+    logger.info("Test containers ready for session")
+    yield _container_manager
+    logger.info("=== Cleaning up test session containers ===")
+    _container_manager.stop_containers()
+    logger.info("Test session container cleanup completed")
+
+
+@pytest.fixture(scope="session")
+def flask_app_with_containers(set_up_containers_and_env) -> Flask:
+    """
+    Session-scoped Flask application fixture using test containers.
+
+    This fixture provides a Flask application instance that is configured
+    to use the test containers for all database and service connections.
+
+    Args:
+        set_up_containers_and_env: Container manager fixture
+
+    Returns:
+        Flask: Configured Flask application
+    """
+    logger.info("=== Creating session-scoped Flask application ===")
+    app = _create_app_with_containers()
+    logger.info("Session-scoped Flask application created successfully")
+    return app
+
+
+@pytest.fixture
+def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
+    """
+    Request context fixture for containerized Flask application.
+
+    This fixture provides a Flask request context for tests that need
+    to interact with the Flask application within a request scope.
+
+    Args:
+        flask_app_with_containers: Flask application fixture
+
+    Yields:
+        None: Request context is active during yield
+    """
+    logger.debug("Creating Flask request context...")
+    with flask_app_with_containers.test_request_context():
+        logger.debug("Flask request context active")
+        yield
+    logger.debug("Flask request context closed")
+
+
+@pytest.fixture
+def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
+    """
+    Test client fixture for containerized Flask application.
+
+    This fixture provides a Flask test client that can be used to make
+    HTTP requests to the containerized application for integration testing.
+ + Args: + flask_app_with_containers: Flask application fixture + + Yields: + FlaskClient: Test client instance + """ + logger.debug("Creating Flask test client...") + with flask_app_with_containers.test_client() as client: + logger.debug("Flask test client ready") + yield client + logger.debug("Flask test client closed") + + +@pytest.fixture +def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]: + """ + Database session fixture for containerized testing. + + This fixture provides a SQLAlchemy database session that is connected + to the test PostgreSQL container, allowing tests to interact with + the database directly. + + Args: + flask_app_with_containers: Flask application fixture + + Yields: + Session: Database session instance + """ + logger.debug("Creating database session...") + with flask_app_with_containers.app_context(): + session = db.session() + logger.debug("Database session created and ready") + try: + yield session + finally: + session.close() + logger.debug("Database session closed") diff --git a/api/tests/test_containers_integration_tests/factories/__init__.py b/api/tests/test_containers_integration_tests/factories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py new file mode 100644 index 0000000000..d6e14f3f54 --- /dev/null +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -0,0 +1,371 @@ +import unittest +from datetime import UTC, datetime +from typing import Optional +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.file import File, FileTransferMethod, FileType +from extensions.ext_database import db +from factories.file_factory import StorageKeyLoader +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + + +@pytest.mark.usefixtures("flask_req_ctx_with_containers") +class TestStorageKeyLoader(unittest.TestCase): + """ + Integration tests for StorageKeyLoader class. + + Tests the batched loading of storage keys from the database for files + with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. 
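+
+    Each test follows the same pattern (a summary of the code below, not an
+    additional requirement): insert an UploadFile or ToolFile row into the
+    containerized PostgreSQL database, wrap its id in a core.file.File with the
+    matching transfer method, call StorageKeyLoader.load_storage_keys(), and
+    assert that File._storage_key now holds the stored key.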
+ """ + + def setUp(self): + """Set up test data before each test method.""" + self.session = db.session() + self.tenant_id = str(uuid4()) + self.user_id = str(uuid4()) + self.conversation_id = str(uuid4()) + + # Create test data that will be cleaned up after each test + self.test_upload_files = [] + self.test_tool_files = [] + + # Create StorageKeyLoader instance + self.loader = StorageKeyLoader(self.session, self.tenant_id) + + def tearDown(self): + """Clean up test data after each test method.""" + self.session.rollback() + + def _create_upload_file( + self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> UploadFile: + """Helper method to create an UploadFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if storage_key is None: + storage_key = f"test_storage_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=storage_key, + name="test_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file.id = file_id + + self.session.add(upload_file) + self.session.flush() + self.test_upload_files.append(upload_file) + + return upload_file + + def _create_tool_file( + self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> ToolFile: + """Helper method to create a ToolFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if file_key is None: + file_key = f"test_file_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + tool_file = ToolFile() + tool_file.id = file_id + tool_file.user_id = self.user_id + tool_file.tenant_id = tenant_id + tool_file.conversation_id = self.conversation_id + tool_file.file_key = file_key + tool_file.mimetype = "text/plain" + tool_file.original_url = "http://example.com/file.txt" + tool_file.name = "test_tool_file.txt" + tool_file.size = 2048 + + self.session.add(tool_file) + self.session.flush() + self.test_tool_files.append(tool_file) + + return tool_file + + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: + """Helper method to create a File object for testing.""" + if tenant_id is None: + tenant_id = self.tenant_id + + # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods + file_related_id = None + remote_url = None + + if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): + file_related_id = related_id + elif transfer_method == FileTransferMethod.REMOTE_URL: + remote_url = "https://example.com/test_file.txt" + file_related_id = related_id + + return File( + id=str(uuid4()), # Generate new UUID for File.id + tenant_id=tenant_id, + type=FileType.DOCUMENT, + transfer_method=transfer_method, + related_id=file_related_id, + remote_url=remote_url, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="initial_key", + ) + + def test_load_storage_keys_local_file(self): + """Test loading storage keys for LOCAL_FILE transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly 
+ assert file._storage_key == upload_file.key + + def test_load_storage_keys_remote_url(self): + """Test loading storage keys for REMOTE_URL transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_tool_file(self): + """Test loading storage keys for TOOL_FILE transfer method.""" + # Create test data + tool_file = self._create_tool_file() + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == tool_file.file_key + + def test_load_storage_keys_mixed_methods(self): + """Test batch loading with mixed transfer methods.""" + # Create test data for different transfer methods + upload_file1 = self._create_upload_file() + upload_file2 = self._create_upload_file() + tool_file = self._create_tool_file() + + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + files = [file1, file2, file3] + + # Load storage keys + self.loader.load_storage_keys(files) + + # Verify all storage keys were loaded correctly + assert file1._storage_key == upload_file1.key + assert file2._storage_key == upload_file2.key + assert file3._storage_key == tool_file.file_key + + def test_load_storage_keys_empty_list(self): + """Test with empty file list.""" + # Should not raise any exceptions + self.loader.load_storage_keys([]) + + def test_load_storage_keys_tenant_mismatch(self): + """Test tenant_id validation.""" + # Create file with different tenant_id + upload_file = self._create_upload_file() + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) + + # Should raise ValueError for tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_missing_file_id(self): + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None + + # Should raise ValueError for None file related_id + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert str(context.value) == "file id should not be None." 
+ + def test_load_storage_keys_nonexistent_upload_file_records(self): + """Test with missing UploadFile database records.""" + # Create file with non-existent upload file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_nonexistent_tool_file_records(self): + """Test with missing ToolFile database records.""" + # Create file with non-existent tool file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_invalid_uuid(self): + """Test with invalid UUID format.""" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" + + # Should raise ValueError for invalid UUID + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_batch_efficiency(self): + """Test batched operations use efficient queries.""" + # Create multiple files of different types + upload_files = [self._create_upload_file() for _ in range(3)] + tool_files = [self._create_tool_file() for _ in range(2)] + + files = [] + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) + + # Mock the session to count queries + with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: + self.loader.load_storage_keys(files) + + # Should make exactly 2 queries (one for upload_files, one for tool_files) + assert mock_scalars.call_count == 2 + + # Verify all storage keys were loaded correctly + for i, file in enumerate(files[:3]): + assert file._storage_key == upload_files[i].key + for i, file in enumerate(files[3:]): + assert file._storage_key == tool_files[i].file_key + + def test_load_storage_keys_tenant_isolation(self): + """Test that tenant isolation works correctly.""" + # Create files for different tenants + other_tenant_id = str(uuid4()) + + # Create upload file for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create upload file for other tenant (but don't add to cleanup list) + upload_file_other = UploadFile( + tenant_id=other_tenant_id, + storage_type="local", + key="other_tenant_key", + name="other_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file_other.id = str(uuid4()) + self.session.add(upload_file_other) + self.session.flush() + + # Create file for other tenant but try to load with current tenant's loader + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError due to tenant mismatch + with 
pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + # Current tenant's file should still work + self.loader.load_storage_keys([file_current]) + assert file_current._storage_key == upload_file_current.key + + def test_load_storage_keys_mixed_tenant_batch(self): + """Test batch with mixed tenant files (should fail on first mismatch).""" + # Create files for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create file for different tenant + other_tenant_id = str(uuid4()) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError on tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_current, file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_duplicate_file_ids(self): + """Test handling of duplicate file IDs in the batch.""" + # Create upload file + upload_file = self._create_upload_file() + + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should handle duplicates gracefully + self.loader.load_storage_keys([file1, file2]) + + # Both files should have the same storage key + assert file1._storage_key == upload_file.key + assert file2._storage_key == upload_file.key + + def test_load_storage_keys_session_isolation(self): + """Test that the loader uses the provided session correctly.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Create loader with different session (same underlying connection) + + with Session(bind=db.engine) as other_session: + other_loader = StorageKeyLoader(other_session, self.tenant_id) + with pytest.raises(ValueError): + other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/services/__init__.py b/api/tests/test_containers_integration_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py new file mode 100644 index 0000000000..3d7be0df7d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -0,0 +1,3340 @@ +import json +from hashlib import sha256 +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import Unauthorized + +from configs import dify_config +from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace +from models.account import AccountStatus, TenantAccountJoin +from services.account_service import AccountService, RegisterService, TenantService, TokenPair +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountLoginError, + AccountNotFoundError, + AccountPasswordError, + AccountRegisterError, + CurrentPasswordIncorrectError, +) +from services.errors.workspace import WorkSpaceNotAllowedCreateError, 
WorkspacesLimitExceededError + + +class TestAccountService: + """Integration tests for AccountService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation and login with correct password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + assert account.email == email + assert account.status == AccountStatus.ACTIVE.value + + # Login with correct password + logged_in = AccountService.authenticate(email, password) + assert logged_in.id == account.id + + def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation without password (for OAuth users). + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + assert account.email == email + assert account.password is None + assert account.password_salt is None + + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation when registration is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks to disable registration + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = False + + with pytest.raises(AccountNotFound): # AccountNotFound exception + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=fake.password(length=12), + ) + + def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation when email is in freeze period. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + dify_config.BILLING_ENABLED = True + + with pytest.raises(AccountRegisterError): + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + dify_config.BILLING_ENABLED = False # Reset config for other tests + + def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with non-existent account. + """ + fake = Faker() + email = fake.email() + password = fake.password(length=12) + with pytest.raises(AccountNotFoundError): + AccountService.authenticate(email, password) + + def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with banned account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account first + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(AccountLoginError): + AccountService.authenticate(email, password) + + def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with wrong password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + correct_password = fake.password(length=12) + wrong_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account first + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=correct_password, + ) + + with pytest.raises(AccountPasswordError): + AccountService.authenticate(email, wrong_password) + + def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with invite token to set password for account without password. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account without password + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Authenticate with invite token to set password + authenticated_account = AccountService.authenticate( + email, + new_password, + invite_token="valid_invite_token", + ) + + assert authenticated_account.id == account.id + assert authenticated_account.password is not None + assert authenticated_account.password_salt is not None + + def test_authenticate_pending_account_activation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test authentication activates pending account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account with pending status + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + account.status = AccountStatus.PENDING.value + from extensions.ext_database import db + + db.session.commit() + + # Authenticate should activate the account + authenticated_account = AccountService.authenticate(email, password) + assert authenticated_account.status == AccountStatus.ACTIVE.value + assert authenticated_account.initialized_at is not None + + def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful password update. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + # Update password + updated_account = AccountService.update_account_password(account, old_password, new_password) + + # Verify new password works + authenticated_account = AccountService.authenticate(email, new_password) + assert authenticated_account.id == account.id + + def test_update_account_password_wrong_current_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test password update with wrong current password. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + wrong_password = fake.password(length=12) + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + with pytest.raises(CurrentPasswordIncorrectError): + AccountService.update_account_password(account, wrong_password, new_password) + + def test_update_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test password update with invalid new password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + # Test with too short password (assuming minimum length validation) + with pytest.raises(ValueError): # Password validation error + AccountService.update_account_password(account, old_password, "123") + + def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation with automatic tenant creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + account = AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + assert account.email == email + + # Verify tenant was created and linked + from extensions.ext_database import db + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_create_account_and_tenant_workspace_creation_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation when workspace creation is disabled. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + with pytest.raises(WorkSpaceNotAllowedCreateError): + AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + def test_create_account_and_tenant_workspace_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation when workspace limit is exceeded. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + with pytest.raises(WorkspacesLimitExceededError): + AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test linking account with new OAuth provider. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Link with new provider + AccountService.link_account_integrate("new-google", "google_open_id_123", account) + + # Verify integration was created + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + assert integration is not None + assert integration.open_id == "google_open_id_123" + + def test_link_account_integrate_existing_provider( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test linking account with existing provider (should update). 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Link with provider first time + AccountService.link_account_integrate("exists-google", "google_open_id_123", account) + + # Link with same provider but different open_id (should update) + AccountService.link_account_integrate("exists-google", "google_open_id_456", account) + + # Verify integration was updated + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = ( + db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + ) + assert integration.open_id == "google_open_id_456" + + def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test closing an account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Close account + AccountService.close_account(account) + + # Verify account status changed + from extensions.ext_database import db + + db.session.refresh(account) + assert account.status == AccountStatus.CLOSED.value + + def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating account fields. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + updated_name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Update account fields + updated_account = AccountService.update_account(account, name=updated_name, interface_theme="dark") + + assert updated_account.name == updated_name + assert updated_account.interface_theme == "dark" + + def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating account with invalid field. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + with pytest.raises(AttributeError): + AccountService.update_account(account, invalid_field="value") + + def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating login information. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Update login info + AccountService.update_login_info(account, ip_address=ip_address) + + # Verify login info was updated + from extensions.ext_database import db + + db.session.refresh(account) + assert account.last_login_ip == ip_address + assert account.last_login_at is not None + + def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful login with token generation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login + token_pair = AccountService.login(account, ip_address=ip_address) + + assert isinstance(token_pair, TokenPair) + assert token_pair.access_token == "mock_access_token" + assert token_pair.refresh_token is not None + + # Verify passport service was called with correct parameters + mock_passport = mock_external_service_dependencies["passport_service"].return_value + mock_passport.issue.assert_called_once() + call_args = mock_passport.issue.call_args[0][0] + assert call_args["user_id"] == account.id + assert call_args["iss"] is not None + assert call_args["sub"] == "Console API Passport" + + def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test login activates pending account. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account with pending status + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + account.status = AccountStatus.PENDING.value + from extensions.ext_database import db + + db.session.commit() + + # Login should activate the account + token_pair = AccountService.login(account) + + db.session.refresh(account) + assert account.status == AccountStatus.ACTIVE.value + + def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test logout functionality. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login first to get refresh token + token_pair = AccountService.login(account) + + # Logout + AccountService.logout(account=account) + + # Verify refresh token was deleted from Redis + from extensions.ext_redis import redis_client + + refresh_token_key = f"account_refresh_token:{account.id}" + assert redis_client.get(refresh_token_key) is None + + def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful token refresh. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "new_mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Login to get initial tokens + initial_token_pair = AccountService.login(account) + + # Refresh token + new_token_pair = AccountService.refresh_token(initial_token_pair.refresh_token) + + assert isinstance(new_token_pair, TokenPair) + assert new_token_pair.access_token == "new_mock_access_token" + assert new_token_pair.refresh_token != initial_token_pair.refresh_token + + def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test refresh token with invalid token. 
+ """ + fake = Faker() + invalid_token = fake.uuid4() + with pytest.raises(ValueError, match="Invalid refresh token"): + AccountService.refresh_token(invalid_token) + + def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test refresh token with valid token but invalid account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login to get tokens + token_pair = AccountService.login(account) + + # Delete account + from extensions.ext_database import db + + db.session.delete(account) + db.session.commit() + + # Try to refresh token with deleted account + with pytest.raises(ValueError, match="Invalid account"): + AccountService.refresh_token(token_pair.refresh_token) + + def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading user by ID successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Load user + loaded_user = AccountService.load_user(account.id) + + assert loaded_user is not None + assert loaded_user.id == account.id + assert loaded_user.email == account.email + + def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading non-existent user. + """ + fake = Faker() + non_existent_user_id = fake.uuid4() + loaded_user = AccountService.load_user(non_existent_user_id) + assert loaded_user is None + + def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading banned user raises Unauthorized. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(Unauthorized): # Unauthorized exception + AccountService.load_user(account.id) + + def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test JWT token generation for account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_jwt_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate JWT token + token = AccountService.get_account_jwt_token(account) + + assert token == "mock_jwt_token" + + # Verify passport service was called with correct parameters + mock_passport = mock_external_service_dependencies["passport_service"].return_value + mock_passport.issue.assert_called_once() + call_args = mock_passport.issue.call_args[0][0] + assert call_args["user_id"] == account.id + assert call_args["iss"] is not None + assert call_args["sub"] == "Console API Passport" + + def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading logged in account by ID. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Load logged in account + loaded_account = AccountService.load_logged_in_account(account_id=account.id) + + assert loaded_account is not None + assert loaded_account.id == account.id + + def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through email successfully. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Get user through email + found_user = AccountService.get_user_through_email(email) + + assert found_user is not None + assert found_user.id == account.id + + def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through non-existent email. + """ + fake = Faker() + non_existent_email = fake.email() + found_user = AccountService.get_user_through_email(non_existent_email) + assert found_user is None + + def test_get_user_through_email_banned_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting banned user through email raises Unauthorized. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(Unauthorized): # Unauthorized exception + AccountService.get_user_through_email(email) + + def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through email that is in freeze period. + """ + fake = Faker() + email_in_freeze = fake.email() + # Setup mocks + dify_config.BILLING_ENABLED = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + + with pytest.raises(AccountRegisterError): + AccountService.get_user_through_email(email_in_freeze) + + # Reset config + dify_config.BILLING_ENABLED = False + + def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account deletion (should add task to queue). + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + with patch("services.account_service.delete_account_task") as mock_delete_task: + # Delete account + AccountService.delete_account(account) + + # Verify task was added to queue + mock_delete_task.delay.assert_called_once_with(account.id) + + def test_generate_account_deletion_verification_code( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generating account deletion verification code. 
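+        The generated code is expected to be a 6-digit numeric string returned together with a verification token.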
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + assert token is not None + assert code is not None + assert len(code) == 6 + assert code.isdigit() + + def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test verifying valid account deletion code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + # Verify code + is_valid = AccountService.verify_account_deletion_code(token, code) + assert is_valid is True + + def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test verifying invalid account deletion code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + wrong_code = fake.numerify(text="######") + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + # Verify with wrong code + is_valid = AccountService.verify_account_deletion_code(token, wrong_code) + assert is_valid is False + + def test_verify_account_deletion_code_invalid_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test verifying account deletion code with invalid token. 
+ """ + fake = Faker() + invalid_token = fake.uuid4() + invalid_code = fake.numerify(text="######") + is_valid = AccountService.verify_account_deletion_code(invalid_token, invalid_code) + assert is_valid is False + + +class TestTenantService: + """Integration tests for TenantService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + } + + def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant creation with default settings. + """ + fake = Faker() + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + assert tenant.name == tenant_name + assert tenant.plan == "basic" + assert tenant.status == "normal" + assert tenant.encrypt_public_key is not None + + def test_create_tenant_workspace_creation_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant creation when workspace creation is disabled. + """ + fake = Faker() + tenant_name = fake.company() + # Setup mocks to disable workspace creation + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception + TenantService.create_tenant(name=tenant_name) + + def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tenant creation with custom name and setup flag. + """ + fake = Faker() + custom_tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + # Create tenant with setup flag (should bypass workspace creation restriction) + tenant = TenantService.create_tenant(name=custom_tenant_name, is_setup=True, is_from_dashboard=True) + + assert tenant.name == custom_tenant_name + assert tenant.plan == "basic" + assert tenant.status == "normal" + assert tenant.encrypt_public_key is not None + + def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant member creation. 
+ """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create tenant member + tenant_member = TenantService.create_tenant_member(tenant, account, role="admin") + + assert tenant_member.tenant_id == tenant.id + assert tenant_member.account_id == account.id + assert tenant_member.role == "admin" + + def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test creating duplicate owner for a tenant (should fail). + """ + fake = Faker() + tenant_name = fake.company() + email1 = fake.email() + name1 = fake.name() + password1 = fake.password(length=12) + email2 = fake.email() + name2 = fake.name() + password2 = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + account1 = AccountService.create_account( + email=email1, + name=name1, + interface_language="en-US", + password=password1, + ) + account2 = AccountService.create_account( + email=email2, + name=name2, + interface_language="en-US", + password=password2, + ) + + # Create first owner + TenantService.create_tenant_member(tenant, account1, role="owner") + + # Try to create second owner (should fail) + with pytest.raises(Exception, match="Tenant already has an owner"): + TenantService.create_tenant_member(tenant, account2, role="owner") + + def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating role for existing tenant member. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create member with initial role + tenant_member1 = TenantService.create_tenant_member(tenant, account, role="normal") + assert tenant_member1.role == "normal" + + # Update member role + tenant_member2 = TenantService.create_tenant_member(tenant, account, role="editor") + assert tenant_member2.tenant_id == tenant_member1.tenant_id + assert tenant_member2.account_id == tenant_member1.account_id + assert tenant_member2.role == "editor" + + def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting join tenants for an account. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant1_name = fake.company() + tenant2_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenants + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + + # Add account to both tenants + TenantService.create_tenant_member(tenant1, account, role="normal") + TenantService.create_tenant_member(tenant2, account, role="admin") + + # Get join tenants + join_tenants = TenantService.get_join_tenants(account) + + assert len(join_tenants) == 2 + tenant_names = [tenant.name for tenant in join_tenants] + assert tenant1_name in tenant_names + assert tenant2_name in tenant_names + + def test_get_current_tenant_by_account_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting current tenant by account successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant = TenantService.create_tenant(name=tenant_name) + + # Add account to tenant and set as current + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + from extensions.ext_database import db + + db.session.commit() + + # Get current tenant + current_tenant = TenantService.get_current_tenant_by_account(account) + + assert current_tenant.id == tenant.id + assert current_tenant.name == tenant.name + assert current_tenant.role == "owner" + + def test_get_current_tenant_by_account_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting current tenant when account has no current tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account without setting current tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to get current tenant (should fail) + with pytest.raises(AttributeError): + TenantService.get_current_tenant_by_account(account) + + def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant switching. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant1_name = fake.company() + tenant2_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenants + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + + # Add account to both tenants + TenantService.create_tenant_member(tenant1, account, role="owner") + TenantService.create_tenant_member(tenant2, account, role="admin") + + # Set initial current tenant + account.current_tenant = tenant1 + from extensions.ext_database import db + + db.session.commit() + + # Switch to second tenant + TenantService.switch_tenant(account, tenant2.id) + + # Verify tenant was switched + db.session.refresh(account) + assert account.current_tenant_id == tenant2.id + + def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tenant switching without providing tenant ID. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to switch tenant without providing tenant ID + with pytest.raises(ValueError, match="Tenant ID must be provided"): + TenantService.switch_tenant(account, None) + + def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test switching to a tenant where account is not a member. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant = TenantService.create_tenant(name=tenant_name) + + # Try to switch to tenant where account is not a member + with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): + TenantService.switch_tenant(account, tenant.id) + + def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking if tenant has specific roles. 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + admin_account = AccountService.create_account( + email=admin_email, + name=admin_name, + interface_language="en-US", + password=admin_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, admin_account, role="admin") + + # Check if tenant has owner role + from models.account import TenantAccountRole + + has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER]) + assert has_owner is True + + # Check if tenant has admin role + has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN]) + assert has_admin is True + + # Check if tenant has normal role (should be False) + has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) + assert has_normal is False + + def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking roles with invalid role type. + """ + fake = Faker() + tenant_name = fake.company() + invalid_role = fake.word() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + # Try to check roles with invalid role type + with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): + TenantService.has_roles(tenant, [invalid_role]) + + def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user role in a tenant. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant with specific role + TenantService.create_tenant_member(tenant, account, role="editor") + + # Get user role + user_role = TenantService.get_user_role(account, tenant) + + assert user_role == "editor" + + def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking member permission successfully. 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Check owner permission to add member (should succeed) + TenantService.check_member_permission(tenant, owner_account, member_account, "add") + + def test_check_member_permission_invalid_action( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test checking member permission with invalid action. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + invalid_action = fake.word() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to check permission with invalid action + with pytest.raises(Exception, match="Invalid action"): + TenantService.check_member_permission(tenant, account, None, invalid_action) + + def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking member permission when trying to operate self. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to check permission to operate self + with pytest.raises(Exception, match="Cannot operate self"): + TenantService.check_member_permission(tenant, account, account, "remove") + + def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful member removal from tenant. 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Remove member + TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + + # Verify member was removed + from extensions.ext_database import db + from models.account import TenantAccountJoin + + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert member_join is None + + def test_remove_member_from_tenant_operate_self( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test removing member when trying to operate self. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to remove self + with pytest.raises(Exception, match="Cannot operate self"): + TenantService.remove_member_from_tenant(tenant, account, account) + + def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test removing member who is not in the tenant. 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + non_member_email = fake.email() + non_member_name = fake.name() + non_member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + non_member_account = AccountService.create_account( + email=non_member_email, + name=non_member_name, + interface_language="en-US", + password=non_member_password, + ) + + # Add only owner to tenant + TenantService.create_tenant_member(tenant, owner_account, role="owner") + + # Try to remove non-member + with pytest.raises(Exception, match="Member not in tenant"): + TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) + + def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful member role update. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Update member role + TenantService.update_member_role(tenant, member_account, "admin", owner_account) + + # Verify role was updated + from extensions.ext_database import db + from models.account import TenantAccountJoin + + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert member_join.role == "admin" + + def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating member role to owner (should change current owner to admin). 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="admin") + + # Update member role to owner + TenantService.update_member_role(tenant, member_account, "owner", owner_account) + + # Verify roles were updated correctly + from extensions.ext_database import db + from models.account import TenantAccountJoin + + owner_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + ) + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert owner_join.role == "admin" + assert member_join.role == "owner" + + def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating member role to already assigned role. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="admin") + + # Try to update member role to already assigned role + with pytest.raises(Exception, match="The provided role is already assigned to the member"): + TenantService.update_member_role(tenant, member_account, "admin", owner_account) + + def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting tenant count successfully. 
+ """ + fake = Faker() + tenant1_name = fake.company() + tenant2_name = fake.company() + tenant3_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create multiple tenants + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant3 = TenantService.create_tenant(name=tenant3_name) + + # Get tenant count + tenant_count = TenantService.get_tenant_count() + + # Should have at least 3 tenants (may be more from other tests) + assert tenant_count >= 3 + + def test_create_owner_tenant_if_not_exist_new_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant for new user without existing tenants. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + workspace_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create owner tenant + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + assert account.current_tenant is not None + assert account.current_tenant.name == workspace_name + + def test_create_owner_tenant_if_not_exist_existing_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant when user already has a tenant. 
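+        The existing tenant should be reused; no additional workspace should be created and the current tenant should stay unchanged.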
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + existing_tenant_name = fake.company() + new_workspace_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Create account and existing tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + existing_tenant = TenantService.create_tenant(name=existing_tenant_name) + TenantService.create_tenant_member(existing_tenant, account, role="owner") + account.current_tenant = existing_tenant + from extensions.ext_database import db + + db.session.commit() + + # Try to create owner tenant again (should not create new one) + TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) + + # Verify no new tenant was created + tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + assert len(tenant_joins) == 1 + assert account.current_tenant.id == existing_tenant.id + + def test_create_owner_tenant_if_not_exist_workspace_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + workspace_name = fake.company() + # Setup mocks to disable workspace creation + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to create owner tenant (should fail) + with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + + def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting tenant members successfully. 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + normal_email = fake.email() + normal_name = fake.name() + normal_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + admin_account = AccountService.create_account( + email=admin_email, + name=admin_name, + interface_language="en-US", + password=admin_password, + ) + normal_account = AccountService.create_account( + email=normal_email, + name=normal_name, + interface_language="en-US", + password=normal_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, admin_account, role="admin") + TenantService.create_tenant_member(tenant, normal_account, role="normal") + + # Get tenant members + members = TenantService.get_tenant_members(tenant) + + assert len(members) == 3 + member_emails = [member.email for member in members] + assert owner_email in member_emails + assert admin_email in member_emails + assert normal_email in member_emails + + # Verify roles are set correctly + for member in members: + if member.email == owner_email: + assert member.role == "owner" + elif member.email == admin_email: + assert member.role == "admin" + elif member.email == normal_email: + assert member.role == "normal" + + def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting dataset operator members successfully. 
+ """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + operator_email = fake.email() + operator_name = fake.name() + operator_password = fake.password(length=12) + normal_email = fake.email() + normal_name = fake.name() + normal_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + dataset_operator_account = AccountService.create_account( + email=operator_email, + name=operator_name, + interface_language="en-US", + password=operator_password, + ) + normal_account = AccountService.create_account( + email=normal_email, + name=normal_name, + interface_language="en-US", + password=normal_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, dataset_operator_account, role="dataset_operator") + TenantService.create_tenant_member(tenant, normal_account, role="normal") + + # Get dataset operator members + dataset_operators = TenantService.get_dataset_operator_members(tenant) + + assert len(dataset_operators) == 1 + assert dataset_operators[0].email == operator_email + assert dataset_operators[0].role == "dataset_operator" + + def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting custom config successfully. + """ + fake = Faker() + tenant_name = fake.company() + theme = fake.random_element(elements=("dark", "light")) + language = fake.random_element(elements=("zh-CN", "en-US")) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant with custom config + tenant = TenantService.create_tenant(name=tenant_name) + + # Set custom config + custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} + tenant.custom_config_dict = custom_config + from extensions.ext_database import db + + db.session.commit() + + # Get custom config + retrieved_config = TenantService.get_custom_config(tenant.id) + + assert retrieved_config == custom_config + assert retrieved_config["theme"] == theme + assert retrieved_config["language"] == language + assert retrieved_config["feature_flags"]["beta"] is True + + +class TestRegisterService: + """Integration tests for RegisterService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + 
mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful system setup with account creation and tenant setup. + """ + fake = Faker() + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute setup + RegisterService.setup( + email=admin_email, + name=admin_name, + password=admin_password, + ip_address=ip_address, + ) + + # Verify account was created + from extensions.ext_database import db + from models.account import Account + from models.model import DifySetup + + account = db.session.query(Account).filter_by(email=admin_email).first() + assert account is not None + assert account.name == admin_name + assert account.last_login_ip == ip_address + assert account.initialized_at is not None + assert account.status == "active" + + # Verify DifySetup was created + dify_setup = db.session.query(DifySetup).first() + assert dify_setup is not None + + # Verify tenant was created and linked + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test setup failure with proper rollback of all created entities. + """ + fake = Faker() + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account to raise exception + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.side_effect = Exception("Database error") + + # Execute setup and verify exception + with pytest.raises(ValueError, match="Setup failed: Database error"): + RegisterService.setup( + email=admin_email, + name=admin_name, + password=admin_password, + ip_address=ip_address, + ) + + # Verify no entities were created (rollback worked) + from extensions.ext_database import db + from models.account import Account, Tenant, TenantAccountJoin + from models.model import DifySetup + + account = db.session.query(Account).filter_by(email=admin_email).first() + tenant_count = db.session.query(Tenant).count() + tenant_join_count = db.session.query(TenantAccountJoin).count() + dify_setup_count = db.session.query(DifySetup).count() + + assert account is None + assert tenant_count == 0 + assert tenant_join_count == 0 + assert dify_setup_count == 0 + + def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful account registration with workspace creation. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + assert account.current_tenant is not None + assert account.current_tenant.name == f"{name}'s Workspace" + + def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration with OAuth integration. + """ + fake = Faker() + email = fake.email() + name = fake.name() + open_id = fake.uuid4() + provider = fake.random_element(elements=("google", "github", "microsoft")) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration with OAuth + account = RegisterService.register( + email=email, + name=name, + password=None, + open_id=open_id, + provider=provider, + language=language, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify OAuth integration was created + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + assert integration is not None + assert integration.open_id == open_id + + def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration with pending status. 
+        """
+        fake = Faker()
+        email = fake.email()
+        name = fake.name()
+        password = fake.password(length=12)
+        language = fake.random_element(elements=("en-US", "zh-CN"))
+        # Setup mocks
+        mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+        mock_external_service_dependencies[
+            "feature_service"
+        ].get_system_features.return_value.is_allow_create_workspace = True
+        mock_external_service_dependencies[
+            "feature_service"
+        ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+        mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+        # Execute registration with pending status
+        from models.account import AccountStatus
+
+        account = RegisterService.register(
+            email=email,
+            name=name,
+            password=password,
+            language=language,
+            status=AccountStatus.PENDING,
+        )
+
+        # Verify account was created with pending status
+        assert account.email == email
+        assert account.name == name
+        assert account.status == "pending"
+        assert account.initialized_at is not None
+
+        # Verify tenant was created and linked
+        from extensions.ext_database import db
+        from models.account import TenantAccountJoin
+
+        tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first()
+        assert tenant_join is not None
+        assert tenant_join.role == "owner"
+
+    def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test account registration when workspace creation is disabled.
+        """
+        fake = Faker()
+        email = fake.email()
+        name = fake.name()
+        password = fake.password(length=12)
+        language = fake.random_element(elements=("en-US", "zh-CN"))
+        # Setup mocks
+        mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+        mock_external_service_dependencies[
+            "feature_service"
+        ].get_system_features.return_value.is_allow_create_workspace = False
+        mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+        # Registration should still succeed even though workspace creation is disabled
+        account = RegisterService.register(
+            email=email,
+            name=name,
+            password=password,
+            language=language,
+        )
+
+        # Verify account was created with no tenant
+        assert account.email == email
+        assert account.name == name
+        assert account.status == "active"
+        assert account.initialized_at is not None
+
+        # Verify no tenant was created or linked
+        from extensions.ext_database import db
+        from models.account import TenantAccountJoin
+
+        tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first()
+        assert tenant_join is None
+
+    def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test account registration when workspace limit is exceeded.
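+        Registration is still expected to succeed, but no workspace should be created for the account.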
+        """
+        fake = Faker()
+        email = fake.email()
+        name = fake.name()
+        password = fake.password(length=12)
+        language = fake.random_element(elements=("en-US", "zh-CN"))
+        # Setup mocks
+        mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+        mock_external_service_dependencies[
+            "feature_service"
+        ].get_system_features.return_value.is_allow_create_workspace = True
+        mock_external_service_dependencies[
+            "feature_service"
+        ].get_system_features.return_value.license.workspaces.is_available.return_value = False
+        mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+        # Registration should still succeed even though the workspace limit is exceeded
+        account = RegisterService.register(
+            email=email,
+            name=name,
+            password=password,
+            language=language,
+        )
+
+        # Verify account was created with no tenant
+        assert account.email == email
+        assert account.name == name
+        assert account.status == "active"
+        assert account.initialized_at is not None
+
+        # Verify no tenant was created or linked
+        from extensions.ext_database import db
+        from models.account import TenantAccountJoin
+
+        tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first()
+        assert tenant_join is None
+
+    def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test account registration without workspace creation.
+        """
+        fake = Faker()
+        email = fake.email()
+        name = fake.name()
+        password = fake.password(length=12)
+        language = fake.random_element(elements=("en-US", "zh-CN"))
+        # Setup mocks
+        mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+        mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+        # Execute registration without workspace creation
+        account = RegisterService.register(
+            email=email,
+            name=name,
+            password=password,
+            language=language,
+            create_workspace_required=False,
+        )
+
+        # Verify account was created
+        assert account.email == email
+        assert account.name == name
+        assert account.status == "active"
+        assert account.initialized_at is not None
+
+        # Verify no tenant was created
+        from extensions.ext_database import db
+        from models.account import TenantAccountJoin
+
+        tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first()
+        assert tenant_join is None
+
+    def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test inviting a new member who doesn't have an account yet.
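+        A pending account should be created, added to the tenant, and an invitation email task queued.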
+ """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + new_member_email = fake.email() + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + + # Execute invitation + token = RegisterService.invite_new_member( + tenant=tenant, + email=new_member_email, + language=language, + role="normal", + inviter=inviter, + ) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify email task was called + mock_send_mail.delay.assert_called_once() + + # Verify new account was created with pending status + from extensions.ext_database import db + from models.account import Account, TenantAccountJoin + + new_account = db.session.query(Account).filter_by(email=new_member_email).first() + assert new_account is not None + assert new_account.name == new_member_email.split("@")[0] # Default name from email + assert new_account.status == "pending" + + # Verify tenant member was created + tenant_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + ) + assert tenant_join is not None + assert tenant_join.role == "normal" + + def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting an existing member who is not in the tenant yet. 
+ """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + existing_member_email = fake.email() + existing_member_name = fake.name() + existing_member_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account + existing_account = AccountService.create_account( + email=existing_member_email, + name=existing_member_name, + interface_language="en-US", + password=existing_member_password, + ) + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): + # Execute invitation + token = RegisterService.invite_new_member( + tenant=tenant, + email=existing_member_email, + language=language, + role="admin", + inviter=inviter, + ) + + # Verify email task was not called + mock_send_mail.delay.assert_not_called() + + # Verify tenant member was created for existing account + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + ) + assert tenant_join is not None + assert tenant_join.role == "admin" + + def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a member who is already in the tenant with pending status. 
+ """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + existing_pending_member_email = fake.email() + existing_pending_member_name = fake.name() + existing_pending_member_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account with pending status + existing_account = AccountService.create_account( + email=existing_pending_member_email, + name=existing_pending_member_name, + interface_language="en-US", + password=existing_pending_member_password, + ) + existing_account.status = "pending" + from extensions.ext_database import db + + db.session.commit() + + # Add existing account to tenant + TenantService.create_tenant_member(tenant, existing_account, role="normal") + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + + # Execute invitation (should resend email for pending member) + token = RegisterService.invite_new_member( + tenant=tenant, + email=existing_pending_member_email, + language=language, + role="normal", + inviter=inviter, + ) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify email task was called + mock_send_mail.delay.assert_called_once() + + def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a member without providing an inviter. + """ + fake = Faker() + tenant_name = fake.company() + new_member_email = fake.email() + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + # Execute invitation without inviter (should fail) + with pytest.raises(ValueError, match="Inviter is required"): + RegisterService.invite_new_member( + tenant=tenant, + email=new_member_email, + language=language, + role="normal", + inviter=None, + ) + + def test_invite_new_member_account_already_in_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test inviting a member who is already in the tenant with active status. 
+ """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + already_in_tenant_email = fake.email() + already_in_tenant_name = fake.name() + already_in_tenant_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account with active status + existing_account = AccountService.create_account( + email=already_in_tenant_email, + name=already_in_tenant_name, + interface_language="en-US", + password=already_in_tenant_password, + ) + existing_account.status = "active" + from extensions.ext_database import db + + db.session.commit() + + # Add existing account to tenant + TenantService.create_tenant_member(tenant, existing_account, role="normal") + + # Execute invitation (should fail for active member) + with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): + RegisterService.invite_new_member( + tenant=tenant, + email=already_in_tenant_email, + language=language, + role="normal", + inviter=inviter, + ) + + def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation of invite token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Execute token generation + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify token was stored in Redis + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + stored_data = redis_client.get(token_key) + assert stored_data is not None + + # Verify stored data contains correct information + import json + + invitation_data = json.loads(stored_data.decode("utf-8")) + assert invitation_data["account_id"] == str(account.id) + assert invitation_data["email"] == account.email + assert invitation_data["workspace_id"] == tenant.id + + def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation of valid invite token. 
+ """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Execute validation + is_valid = RegisterService.is_valid_invite_token(token) + + # Verify token is valid + assert is_valid is True + + def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation of invalid invite token. + """ + fake = Faker() + invalid_token = fake.uuid4() + # Execute validation with non-existent token + is_valid = RegisterService.is_valid_invite_token(invalid_token) + + # Verify token is invalid + assert is_valid is False + + def test_revoke_token_with_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test revoking token with workspace ID and email. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token exists in Redis before revocation + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + assert redis_client.get(token_key) is not None + + # Execute token revocation + RegisterService.revoke_token( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify token was not deleted from Redis + assert redis_client.get(token_key) is not None + + def test_revoke_token_without_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test revoking token without workspace ID and email. 
+ """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token exists in Redis before revocation + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + assert redis_client.get(token_key) is not None + + # Execute token revocation without workspace and email + RegisterService.revoke_token( + workspace_id="", + email="", + token=token, + ) + + # Verify token was deleted from Redis + assert redis_client.get(token_key) is None + + def test_get_invitation_if_token_valid_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with valid token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + email_hash = sha256(account.email.encode()).hexdigest() + cache_key = f"member_invite_token:{tenant.id}, {email_hash}:{token}" + from extensions.ext_redis import redis_client + + redis_client.setex(cache_key, 24 * 60 * 60, account.id) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result contains expected data + assert result is not None + assert result["account"].id == account.id + assert result["tenant"].id == tenant.id + assert result["data"]["account_id"] == str(account.id) + assert result["data"]["email"] == account.email + assert result["data"]["workspace_id"] == tenant.id + + def test_get_invitation_if_token_valid_invalid_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with invalid token. + """ + fake = Faker() + workspace_id = fake.uuid4() + email = fake.email() + invalid_token = fake.uuid4() + # Execute invitation retrieval with invalid token + result = RegisterService.get_invitation_if_token_valid( + workspace_id=workspace_id, + email=email, + token=invalid_token, + ) + + # Verify result is None + assert result is None + + def test_get_invitation_if_token_valid_invalid_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with invalid tenant. 
+ """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + invalid_tenant_id = fake.uuid4() + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create a real token but with non-existent tenant ID + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": str(account.id), + "email": account.email, + "workspace_id": invalid_tenant_id, + } + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=invalid_tenant_id, + email=account.email, + token=token, + ) + + # Verify result is None (tenant not found) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_if_token_valid_account_mismatch( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with account ID mismatch. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Create a real token but with mismatched account ID + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": "different-account-id", # Different from actual account ID + "email": account.email, + "workspace_id": tenant.id, + } + token_key = RegisterService._get_invitation_token_key(token) + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result is None (account ID mismatch) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_if_token_valid_tenant_not_normal( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with tenant not in normal status. 
+ """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Change tenant status to non-normal + tenant.status = "suspended" + from extensions.ext_database import db + + db.session.commit() + + # Create a real token + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": str(account.id), + "email": account.email, + "workspace_id": tenant.id, + } + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result is None (tenant not in normal status) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_by_token_with_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation by token with workspace ID and email. + """ + fake = Faker() + token = fake.uuid4() + workspace_id = fake.uuid4() + email = fake.email() + + # Create the cache key as the service does + from hashlib import sha256 + + from extensions.ext_redis import redis_client + + email_hash = sha256(email.encode()).hexdigest() + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" + + # Store account ID in Redis + account_id = fake.uuid4() + redis_client.setex(cache_key, 24 * 60 * 60, account_id) + + # Execute invitation retrieval + result = RegisterService._get_invitation_by_token( + token=token, + workspace_id=workspace_id, + email=email, + ) + + # Verify result contains expected data + assert result is not None + assert result["account_id"] == account_id + assert result["email"] == email + assert result["workspace_id"] == workspace_id + + # Clean up + redis_client.delete(cache_key) + + def test_get_invitation_by_token_without_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation by token without workspace ID and email. 
+ """ + fake = Faker() + token = fake.uuid4() + invitation_data = { + "account_id": fake.uuid4(), + "email": fake.email(), + "workspace_id": fake.uuid4(), + } + + # Store invitation data in Redis using standard token key + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService._get_invitation_by_token(token=token) + + # Verify result contains expected data + assert result is not None + assert result["account_id"] == invitation_data["account_id"] + assert result["email"] == invitation_data["email"] + assert result["workspace_id"] == invitation_data["workspace_id"] + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting invitation token key. + """ + fake = Faker() + token = fake.uuid4() + # Execute token key generation + token_key = RegisterService._get_invitation_token_key(token) + + # Verify token key format + assert token_key == f"member_invite:token:{token}" diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..0ab5f398e3 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -0,0 +1,1252 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from models.model import MessageAnnotation +from services.annotation_service import AppAnnotationService +from services.app_service import AppService + + +class TestAnnotationService: + """Integration tests for AnnotationService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.annotation_service.FeatureService") as mock_feature_service, + patch("services.annotation_service.add_annotation_to_index_task") as mock_add_task, + patch("services.annotation_service.update_annotation_to_index_task") as mock_update_task, + patch("services.annotation_service.delete_annotation_index_task") as mock_delete_task, + patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, + patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, + patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, + patch("services.annotation_service.current_user") as mock_current_user, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + mock_add_task.delay.return_value = None + mock_update_task.delay.return_value = None + mock_delete_task.delay.return_value = None + mock_enable_task.delay.return_value = None + mock_disable_task.delay.return_value = None + mock_batch_import_task.delay.return_value = None + + yield { + "account_feature_service": mock_account_feature_service, + "feature_service": mock_feature_service, + "add_task": mock_add_task, + "update_task": mock_update_task, + "delete_task": mock_delete_task, + "enable_task": mock_enable_task, + "disable_task": mock_disable_task, + "batch_import_task": 
mock_batch_import_task, + "current_user": mock_current_user, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Setup current_user mock + self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id) + + return app, account + + def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id): + """ + Helper method to mock the current user for testing. + """ + mock_external_service_dependencies["current_user"].id = account_id + mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + + def _create_test_conversation(self, app, account, fake): + """ + Helper method to create a test conversation with all required fields. + """ + from extensions.ext_database import db + from models.model import Conversation + + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=fake.sentence(), + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(conversation) + db.session.flush() + return conversation + + def _create_test_message(self, app, conversation, account, fake): + """ + Helper method to create a test message with all required fields. 
+ """ + import json + + from extensions.ext_database import db + from models.model import Message + + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query=fake.sentence(), + message=json.dumps([{"role": "user", "text": fake.sentence()}]), + message_tokens=0, + message_unit_price=0, + message_price_unit=0.001, + answer=fake.text(max_nb_chars=200), + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0.001, + parent_message_id=None, + provider_response_latency=0, + total_price=0, + currency="USD", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(message) + db.session.commit() + return message + + def test_insert_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify annotation was saved to database + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_insert_app_annotation_directly_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test direct insertion of app annotation when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) + + def test_update_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["update_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_new( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating new annotation from message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_update( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test updating existing annotation from message. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create initial annotation + initial_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + initial_annotation = AppAnnotationService.up_insert_app_annotation_from_message(initial_args, app.id) + + # Update the annotation + updated_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.up_insert_app_annotation_from_message(updated_args, app.id) + + # Verify annotation was updated correctly (same ID) + assert updated_annotation.id == initial_annotation.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.question != initial_args["question"] + assert updated_annotation.content != initial_args["answer"] + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) + + def test_get_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful retrieval of annotation list by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Get annotation list + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="" + ) + + # Verify results + assert len(annotation_list) == 3 + assert total == 3 + + # Verify all annotations belong to the correct app + for annotation in annotation_list: + assert annotation.app_id == app.id + assert annotation.account_id == account.id + + def test_get_annotation_list_by_app_id_with_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list with keyword search. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with specific keywords + unique_keyword = fake.word() + annotation_args = { + "question": f"Question with {unique_keyword} keyword", + "answer": f"Answer with {unique_keyword} keyword", + } + AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Create another annotation without the keyword + other_args = { + "question": "Question without keyword", + "answer": "Answer without keyword", + } + AppAnnotationService.insert_app_annotation_directly(other_args, app.id) + + # Search with keyword + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword=unique_keyword + ) + + # Verify only matching annotations are returned + assert len(annotation_list) == 1 + assert total == 1 + assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + + def test_get_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to get annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") + + def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + from extensions.ext_database import db + + deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["delete_task"].delay.assert_not_called() + + def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deletion of app annotation when app is not found. 
+ """ + fake = Faker() + non_existent_app_id = fake.uuid4() + annotation_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to delete annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) + + def test_delete_app_annotation_annotation_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test deletion of app annotation when annotation is not found. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + non_existent_annotation_id = fake.uuid4() + + # Try to delete non-existent annotation + with pytest.raises(NotFound, match="Annotation not found"): + AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) + + def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["enable_task"].delay.assert_called_once() + + def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Disable annotation + result = AppAnnotationService.disable_app_annotation(app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["disable_task"].delay.assert_called_once() + + def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test enabling app annotation when job is already cached. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return cached job + from extensions.ext_redis import redis_client + + cached_job_id = fake.uuid4() + enable_app_annotation_key = f"enable_app_annotation_{app.id}" + redis_client.set(enable_app_annotation_key, cached_job_id) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation (should return cached job) + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify cached result + assert cached_job_id == result["job_id"].decode("utf-8") + assert result["job_status"] == "processing" + + # Verify task was not called again + mock_external_service_dependencies["enable_task"].delay.assert_not_called() + + # Clean up + redis_client.delete(enable_app_annotation_key) + + def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation hit histories. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Add some hit histories + for i in range(3): + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=f"Query {i}: {fake.sentence()}", + user_id=account.id, + message_id=fake.uuid4(), + from_source="console", + score=0.8 + (i * 0.1), + ) + + # Get hit histories + hit_histories, total = AppAnnotationService.get_annotation_hit_histories( + app.id, annotation.id, page=1, limit=10 + ) + + # Verify results + assert len(hit_histories) == 3 + assert total == 3 + + # Verify all histories belong to the correct annotation + for history in hit_histories: + assert history.annotation_id == annotation.id + assert history.app_id == app.id + assert history.account_id == account.id + + def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful addition of annotation history. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get initial hit count + initial_hit_count = annotation.hit_count + + # Add annotation history + query = fake.sentence() + message_id = fake.uuid4() + score = 0.85 + + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=query, + user_id=account.id, + message_id=message_id, + from_source="console", + score=score, + ) + + # Verify hit count was incremented + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.hit_count == initial_hit_count + 1 + + # Verify history was created + from models.model import AppAnnotationHitHistory + + history = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id + ) + .first() + ) + + assert history is not None + assert history.app_id == app.id + assert history.account_id == account.id + assert history.question == query + assert history.score == score + assert history.source == "console" + + def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation by ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + created_annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get annotation by ID + retrieved_annotation = AppAnnotationService.get_annotation_by_id(created_annotation.id) + + # Verify annotation was retrieved correctly + assert retrieved_annotation is not None + assert retrieved_annotation.id == created_annotation.id + assert retrieved_annotation.app_id == app.id + assert retrieved_annotation.question == annotation_args["question"] + assert retrieved_annotation.content == annotation_args["answer"] + assert retrieved_annotation.account_id == account.id + + def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful batch import of app annotations. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = False + + # Mock pandas to return expected DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() + + def test_batch_import_app_annotations_empty_file( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import with empty CSV file. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create empty CSV content + csv_content = "" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return empty DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame() + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "empty" in result["error_msg"].lower() + + def test_batch_import_app_annotations_quota_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import when quota is exceeded. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Mock FeatureService to return billing enabled with quota exceeded + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = True + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.limit = 1 + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.size = 0 + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "limit" in result["error_msg"].lower() + + def test_get_app_annotation_setting_by_app_id_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting enabled app annotation setting by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Get annotation setting + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.8 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + def test_get_app_annotation_setting_by_app_id_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting disabled app annotation setting by app ID. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Get annotation setting (no setting exists) + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is False + + def test_update_app_annotation_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful update of app annotation setting. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Update annotation setting + update_args = { + "score_threshold": 0.9, + } + + result = AppAnnotationService.update_app_annotation_setting(app.id, annotation_setting.id, update_args) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.9 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + # Verify database was updated + db.session.refresh(annotation_setting) + assert annotation_setting.score_threshold == 0.9 + + def test_export_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful export of annotation list by app ID. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Export annotation list + exported_annotations = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Verify results + assert len(exported_annotations) == 3 + + # Verify all annotations belong to the correct app and are ordered by created_at desc + for i, annotation in enumerate(exported_annotations): + assert annotation.app_id == app.id + assert annotation.account_id == account.id + if i > 0: + # Verify descending order (newer first) + assert annotation.created_at <= exported_annotations[i - 1].created_at + + def test_export_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test export of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to export annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) + + def test_insert_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_update_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Reset mock to clear previous calls + mock_external_service_dependencies["update_task"].delay.reset_mock() + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called + mock_external_service_dependencies["update_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["update_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == updated_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_delete_app_annotation_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful deletion of app annotation with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Reset mock to clear previous calls + mock_external_service_dependencies["delete_task"].delay.reset_mock() + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called + mock_external_service_dependencies["delete_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["delete_task"].delay.call_args[0] + assert call_args[0] == annotation_id # annotation_id + assert call_args[1] == app.id # app_id + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == collection_binding.id # collection_binding_id + + def test_up_insert_app_annotation_from_message_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message with annotation setting enabled. 
+ """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py new file mode 100644 index 0000000000..69cd9fafee --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -0,0 +1,928 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from constants.model_template import default_app_templates +from models.model import App, Site +from services.account_service import AccountService, TenantService +from services.app_service import AppService + + +class TestAppService: + """Integration tests for AppService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") 
as mock_account_feature_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app creation with basic parameters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Verify app was created correctly + assert app.name == app_args["name"] + assert app.description == app_args["description"] + assert app.mode == app_args["mode"] + assert app.icon_type == app_args["icon_type"] + assert app.icon == app_args["icon"] + assert app.icon_background == app_args["icon_background"] + assert app.tenant_id == tenant.id + assert app.api_rph == app_args["api_rph"] + assert app.api_rpm == app_args["api_rpm"] + assert app.created_by == account.id + assert app.updated_by == account.id + assert app.status == "normal" + assert app.enable_site is True + assert app.enable_api is True + assert app.is_demo is False + assert app.is_public is False + assert app.is_universal is False + + def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app creation with different app modes. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Test different app modes + # from AppMode enum in default_app_model_template + app_modes = [v.value for v in default_app_templates] + + for mode in app_modes: + app_args = { + "name": f"{fake.company()} {mode}", + "description": f"Test app for {mode} mode", + "mode": mode, + "icon_type": "emoji", + "icon": "🚀", + "icon_background": "#4ECDC4", + } + + app = app_service.create_app(tenant.id, app_args, account) + + # Verify app mode was set correctly + assert app.mode == mode + assert app.name == app_args["name"] + assert app.tenant_id == tenant.id + assert app.created_by == account.id + + def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app retrieval. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + created_app = app_service.create_app(tenant.id, app_args, account) + + # Get app using the service + retrieved_app = app_service.get_app(created_app) + + # Verify retrieved app matches created app + assert retrieved_app.id == created_app.id + assert retrieved_app.name == created_app.name + assert retrieved_app.description == created_app.description + assert retrieved_app.mode == created_app.mode + assert retrieved_app.tenant_id == created_app.tenant_id + assert retrieved_app.created_by == created_app.created_by + + def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful paginated app list retrieval. 
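+
+        The returned object is assumed to expose pagination attributes in the
+        Flask-SQLAlchemy style (items, page, per_page), which the assertions
+        below rely on.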
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create multiple apps + app_names = [fake.company() for _ in range(5)] + for name in app_names: + app_args = { + "name": name, + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "📱", + "icon_background": "#96CEB4", + } + app_service.create_app(tenant.id, app_args, account) + + # Get paginated apps + args = { + "page": 1, + "limit": 10, + "mode": "chat", + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Verify pagination results + assert paginated_apps is not None + assert len(paginated_apps.items) >= 5 # Should have at least 5 apps + assert paginated_apps.page == 1 + assert paginated_apps.per_page == 10 + + # Verify all apps belong to the correct tenant + for app in paginated_apps.items: + assert app.tenant_id == tenant.id + assert app.mode == "chat" + + def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test paginated app list with various filters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create apps with different modes + chat_app_args = { + "name": "Chat App", + "description": "A chat application", + "mode": "chat", + "icon_type": "emoji", + "icon": "💬", + "icon_background": "#FF6B6B", + } + completion_app_args = { + "name": "Completion App", + "description": "A completion application", + "mode": "completion", + "icon_type": "emoji", + "icon": "✍️", + "icon_background": "#4ECDC4", + } + + chat_app = app_service.create_app(tenant.id, chat_app_args, account) + completion_app = app_service.create_app(tenant.id, completion_app_args, account) + + # Test filter by mode + chat_args = { + "page": 1, + "limit": 10, + "mode": "chat", + } + chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args) + assert len(chat_apps.items) == 1 + assert chat_apps.items[0].mode == "chat" + + # Test filter by name + name_args = { + "page": 1, + "limit": 10, + "mode": "chat", + "name": "Chat", + } + filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args) + assert len(filtered_apps.items) == 1 + assert "Chat" in filtered_apps.items[0].name + + # Test filter by created_by_me + created_by_me_args = { + "page": 1, + "limit": 10, + "mode": "completion", + "is_created_by_me": True, + } + my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) + assert len(my_apps.items) == 1 + + def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test paginated app list with tag filters. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create an app + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🏷️", + "icon_background": "#FFEAA7", + } + app = app_service.create_app(tenant.id, app_args, account) + + # Mock TagService to return the app ID for tag filtering + with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: + mock_tag_service.return_value = [app.id] + + # Test with tag filter + args = { + "page": 1, + "limit": 10, + "mode": "chat", + "tag_ids": ["tag1", "tag2"], + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Verify tag service was called + mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"]) + + # Verify results + assert paginated_apps is not None + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].id == app.id + + # Test with tag filter that returns no results + with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: + mock_tag_service.return_value = [] + + args = { + "page": 1, + "limit": 10, + "mode": "chat", + "tag_ids": ["nonexistent_tag"], + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Should return None when no apps match tag filter + assert paginated_apps is None + + def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app update with all fields. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_name = app.name + original_description = app.description + original_icon = app.icon + original_icon_background = app.icon_background + original_use_icon_as_answer_icon = app.use_icon_as_answer_icon + + # Update app + update_args = { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "emoji", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + } + + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app(app, update_args) + + # Verify updated fields + assert updated_app.name == update_args["name"] + assert updated_app.description == update_args["description"] + assert updated_app.icon == update_args["icon"] + assert updated_app.icon_background == update_args["icon_background"] + assert updated_app.use_icon_as_answer_icon is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app name update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original name + original_name = app.name + + # Update app name + new_name = "New App Name" + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_name(app, new_name) + + assert updated_app.name == new_name + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app icon update. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_icon = app.icon + original_icon_background = app.icon_background + + # Update app icon + new_icon = "🌟" + new_icon_background = "#FFD93D" + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + + assert updated_app.icon == new_icon + assert updated_app.icon_background == new_icon_background + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app site status update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🌐", + "icon_background": "#74B9FF", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original site status + original_site_status = app.enable_site + + # Update site status to disabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_site_status(app, False) + assert updated_app.enable_site is False + assert updated_app.updated_by == account.id + + # Update site status back to enabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_site_status(updated_app, True) + assert updated_app.enable_site is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app API status update. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔌", + "icon_background": "#A29BFE", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original API status + original_api_status = app.enable_api + + # Update API status to disabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_api_status(app, False) + assert updated_app.enable_api is False + assert updated_app.updated_by == account.id + + # Update API status back to enabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_api_status(updated_app, True) + assert updated_app.enable_api is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app site status update when status doesn't change. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔄", + "icon_background": "#FD79A8", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_site_status = app.enable_site + original_updated_at = app.updated_at + + # Update site status to the same value (no change) + updated_app = app_service.update_app_site_status(app, original_site_status) + + # Verify app is returned unchanged + assert updated_app.id == app.id + assert updated_app.enable_site == original_site_status + assert updated_app.updated_at == original_updated_at + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app deletion. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🗑️", + "icon_background": "#E17055", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store app ID for verification + app_id = app.id + + # Mock the async deletion task + with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task: + mock_delete_task.delay.return_value = None + + # Delete app + app_service.delete_app(app) + + # Verify async deletion task was called + mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) + + # Verify app was deleted from database + from extensions.ext_database import db + + deleted_app = db.session.query(App).filter_by(id=app_id).first() + assert deleted_app is None + + def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app deletion with related data cleanup. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🧹", + "icon_background": "#00B894", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store app ID for verification + app_id = app.id + + # Mock webapp auth cleanup + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.webapp_auth.enabled = True + + # Mock the async deletion task + with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task: + mock_delete_task.delay.return_value = None + + # Delete app + app_service.delete_app(app) + + # Verify webapp auth cleanup was called + mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with( + app_id + ) + + # Verify async deletion task was called + mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) + + # Verify app was deleted from database + from extensions.ext_database import db + + deleted_app = db.session.query(App).filter_by(id=app_id).first() + assert deleted_app is None + + def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app metadata retrieval. 
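+
+        get_app_meta is assumed to return a plain dict; this test only relies on
+        the "tool_icons" key being present.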
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "📊", + "icon_background": "#6C5CE7", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Get app metadata + app_meta = app_service.get_app_meta(app) + + # Verify metadata contains expected fields + assert "tool_icons" in app_meta + # Note: get_app_meta currently only returns tool_icons + + def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app code retrieval by app ID. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔗", + "icon_background": "#FDCB6E", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Get app code by ID + app_code = AppService.get_app_code_by_id(app.id) + + # Verify app code was retrieved correctly + # Note: Site would be created when App is created, site.code is auto-generated + assert app_code is not None + assert len(app_code) > 0 + + def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app ID retrieval by app code. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🆔", + "icon_background": "#E84393", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Create a site for the app + site = Site() + site.app_id = app.id + site.code = fake.postalcode() + site.title = fake.company() + site.status = "normal" + site.default_language = "en-US" + site.customize_token_strategy = "uuid" + from extensions.ext_database import db + + db.session.add(site) + db.session.commit() + + # Get app ID by code + app_id = AppService.get_app_id_by_code(site.code) + + # Verify app ID was retrieved correctly + assert app_id == app.id + + def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app creation with invalid mode. 
+ """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments with invalid mode + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "invalid_mode", # Invalid mode + "icon_type": "emoji", + "icon": "❌", + "icon_background": "#D63031", + } + + app_service = AppService() + + # Attempt to create app with invalid mode + with pytest.raises(ValueError, match="invalid mode value"): + app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..85a9355c79 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -0,0 +1,739 @@ +import pytest +from faker import Faker + +from core.variables.segments import StringSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from models import App, Workflow +from models.enums import DraftVariableType +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import ( + UpdateNotSupportedError, + WorkflowDraftVariableService, +) + + +class TestWorkflowDraftVariableService: + """ + Comprehensive integration tests for WorkflowDraftVariableService using testcontainers. + + This test class covers all major functionality of the WorkflowDraftVariableService: + - CRUD operations for workflow draft variables (Create, Read, Update, Delete) + - Variable listing and filtering by type (conversation, system, node) + - Variable updates and resets with proper validation + - Variable deletion operations at different scopes + - Special functionality like prefill and conversation ID retrieval + - Error handling for various edge cases and invalid operations + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """ + Mock setup for external service dependencies. + + WorkflowDraftVariableService doesn't have external dependencies that need mocking, + so this fixture returns an empty dictionary to maintain consistency with other test classes. + This ensures the test structure remains consistent across different service test files. + """ + # WorkflowDraftVariableService doesn't have external dependencies that need mocking + return {} + + def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): + """ + Helper method to create a test app with realistic data for testing. + + This method creates a complete App instance with all required fields populated + using Faker for generating realistic test data. The app is configured for + workflow mode to support workflow draft variable testing. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies (unused in this service) + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + App: Created test app instance with all required fields populated + """ + fake = fake or Faker() + app = App() + app.id = fake.uuid4() + app.tenant_id = fake.uuid4() + app.name = fake.company() + app.description = fake.text() + app.mode = "workflow" + app.icon_type = "emoji" + app.icon = "🤖" + app.icon_background = "#FFEAD5" + app.enable_site = True + app.enable_api = True + app.created_by = fake.uuid4() + app.updated_by = app.created_by + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + return app + + def _create_test_workflow(self, db_session_with_containers, app, fake=None): + """ + Helper method to create a test workflow associated with an app. + + This method creates a Workflow instance using the proper factory method + to ensure all required fields are set correctly. The workflow is configured + as a draft version with basic graph structure for testing workflow variables. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: The app to associate the workflow with + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + Workflow: Created test workflow instance with proper configuration + """ + fake = fake or Faker() + workflow = Workflow.new( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=app.created_by, + environment_variables=[], + conversation_variables=[], + ) + from extensions.ext_database import db + + db.session.add(workflow) + db.session.commit() + return workflow + + def _create_test_variable( + self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None + ): + """ + Helper method to create a test workflow draft variable with proper configuration. + + This method creates different types of variables (conversation, system, node) using + the appropriate factory methods to ensure proper initialization. Each variable type + has specific requirements and this method handles the creation logic for all types. 
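+        A variable_type of "conversation" maps to
+        WorkflowDraftVariable.new_conversation_variable(), "system" maps to
+        new_sys_variable(), and any other value falls through to
+        new_node_variable() (see the branches below).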
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app_id: ID of the app to associate the variable with + node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID) + name: Name of the variable for identification + value: StringSegment value for the variable content + variable_type: Type of variable ("conversation", "system", "node") determining creation method + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + WorkflowDraftVariable: Created test variable instance with proper type configuration + """ + fake = fake or Faker() + if variable_type == "conversation": + # Create conversation variable using the appropriate factory method + variable = WorkflowDraftVariable.new_conversation_variable( + app_id=app_id, + name=name, + value=value, + description=fake.text(max_nb_chars=20), + ) + elif variable_type == "system": + # Create system variable with editable flag and execution context + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + name=name, + value=value, + node_execution_id=fake.uuid4(), + editable=True, + ) + else: # node variable + # Create node variable with visibility and editability settings + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + node_id=node_id, + name=name, + value=value, + node_execution_id=fake.uuid4(), + visible=True, + editable=True, + ) + from extensions.ext_database import db + + db.session.add(variable) + db.session.commit() + return variable + + def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting a single variable by ID successfully. + + This test verifies that the service can retrieve a specific variable + by its ID and that the returned variable contains the correct data. + It ensures the basic CRUD read operation works correctly for workflow draft variables. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variable = service.get_variable(variable.id) + assert retrieved_variable is not None + assert retrieved_variable.id == variable.id + assert retrieved_variable.name == "test_var" + assert retrieved_variable.app_id == app.id + assert retrieved_variable.get_value().value == test_value.value + + def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting a variable that doesn't exist. + + This test verifies that the service returns None when trying to + retrieve a variable with a non-existent ID. This ensures proper + handling of missing data scenarios. + """ + fake = Faker() + non_existent_id = fake.uuid4() + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variable = service.get_variable(non_existent_id) + assert retrieved_variable is None + + def test_get_draft_variables_by_selectors_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting variables by selectors successfully. 
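+
+        Selectors are assumed to be two-element [node_id, variable_name] lists,
+        e.g. [CONVERSATION_VARIABLE_NODE_ID, "var1"] or ["test_node_1", "var3"].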
+ + This test verifies that the service can retrieve multiple variables + using selector pairs (node_id, variable_name) and returns the correct + variables for each selector. This is useful for bulk variable retrieval + operations in workflow execution contexts. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + var1_value = StringSegment(value=fake.word()) + var2_value = StringSegment(value=fake.word()) + var3_value = StringSegment(value=fake.word()) + var1 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake + ) + var2 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake + ) + var3 = self._create_test_variable( + db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake + ) + selectors = [ + [CONVERSATION_VARIABLE_NODE_ID, "var1"], + [CONVERSATION_VARIABLE_NODE_ID, "var2"], + ["test_node_1", "var3"], + ] + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors) + assert len(retrieved_variables) == 3 + var_names = [var.name for var in retrieved_variables] + assert "var1" in var_names + assert "var2" in var_names + assert "var3" in var_names + for var in retrieved_variables: + if var.name == "var1": + assert var.get_value().value == var1_value.value + elif var.name == "var2": + assert var.get_value().value == var2_value.value + elif var.name == "var3": + assert var.get_value().value == var3_value.value + + def test_list_variables_without_values_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test listing variables without values successfully with pagination. + + This test verifies that the service can list variables with pagination + and that the returned variables don't include their values (for performance). + This is important for scenarios where only variable metadata is needed + without loading the actual content. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + for i in range(5): + test_value = StringSegment(value=fake.numerify("value##")) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_variables_without_values(app.id, page=1, limit=3) + assert result.total == 5 + assert len(result.variables) == 3 + assert result.variables[0].created_at >= result.variables[1].created_at + assert result.variables[1].created_at >= result.variables[2].created_at + for var in result.variables: + assert var.name is not None + assert var.app_id == app.id + + def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing variables for a specific node successfully. + + This test verifies that the service can filter and return only + variables associated with a specific node ID. This is crucial for + workflow execution where variables need to be scoped to specific nodes. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + node_id = fake.word() + var1_value = StringSegment(value=fake.word()) + var2_value = StringSegment(value=fake.word()) + var3_value = StringSegment(value=fake.word()) + self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake) + self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake) + self._create_test_variable( + db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_node_variables(app.id, node_id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == node_id + assert var.app_id == app.id + var_names = [var.name for var in result.variables] + assert "var1" in var_names + assert "var2" in var_names + assert "var3" not in var_names + + def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing conversation variables successfully. + + This test verifies that the service can filter and return only + conversation variables, excluding system and node variables. + Conversation variables are user-facing variables that can be + modified during conversation flows. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + conv_var1_value = StringSegment(value=fake.word()) + conv_var2_value = StringSegment(value=fake.word()) + conv_var1 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake + ) + conv_var2 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake + ) + sys_var_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_conversation_variables(app.id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == CONVERSATION_VARIABLE_NODE_ID + assert var.app_id == app.id + assert var.get_variable_type() == DraftVariableType.CONVERSATION + var_names = [var.name for var in result.variables] + assert "conv_var1" in var_names + assert "conv_var2" in var_names + assert "sys_var" not in var_names + + def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating a variable's name and value successfully. + + This test verifies that the service can update both the name and value + of an editable variable and that the changes are persisted correctly. + It also checks that the last_edited_at timestamp is updated appropriately. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + original_value = StringSegment(value=fake.word()) + new_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "original_name", + original_value, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + updated_variable = service.update_variable(variable, name="new_name", value=new_value) + assert updated_variable.name == "new_name" + assert updated_variable.get_value().value == new_value.value + assert updated_variable.last_edited_at is not None + from extensions.ext_database import db + + db.session.refresh(variable) + assert variable.name == "new_name" + assert variable.get_value().value == new_value.value + assert variable.last_edited_at is not None + + def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that updating a non-editable variable raises an exception. + + This test verifies that the service properly prevents updates to + variables that are not marked as editable. This is important for + maintaining data integrity and preventing unauthorized modifications + to system-controlled variables. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + original_value = StringSegment(value=fake.word()) + new_value = StringSegment(value=fake.word()) + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app.id, + name=fake.word(), # This is typically not editable + value=original_value, + node_execution_id=fake.uuid4(), + editable=False, # Set as non-editable + ) + from extensions.ext_database import db + + db.session.add(variable) + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + with pytest.raises(UpdateNotSupportedError) as exc_info: + service.update_variable(variable, name="new_name", value=new_value) + assert "variable not support updating" in str(exc_info.value) + assert variable.id in str(exc_info.value) + + def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test resetting conversation variable successfully. + + This test verifies that the service can reset a conversation variable + to its default value and clear the last_edited_at timestamp. + This functionality is useful for reverting user modifications + back to the original workflow configuration. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) + from core.variables.variables import StringVariable + + conv_var = StringVariable( + id=fake.uuid4(), + name="test_conv_var", + value="default_value", + selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], + ) + workflow.conversation_variables = [conv_var] + from extensions.ext_database import db + + db.session.commit() + modified_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "test_conv_var", + modified_value, + fake=fake, + ) + variable.last_edited_at = fake.date_time() + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + reset_variable = service.reset_variable(workflow, variable) + assert reset_variable is not None + assert reset_variable.get_value().value == "default_value" + assert reset_variable.last_edited_at is None + db.session.refresh(variable) + assert variable.get_value().value == "default_value" + assert variable.last_edited_at is None + + def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting a single variable successfully. + + This test verifies that the service can delete a specific variable + and that it's properly removed from the database. It ensures that + the deletion operation is atomic and complete. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + ) + from extensions.ext_database import db + + assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_variable(variable) + assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None + + def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting all variables for a workflow successfully. + + This test verifies that the service can delete all variables + associated with a specific app/workflow. This is useful for + cleanup operations when workflows are deleted or reset. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + for i in range(3): + test_value = StringSegment(value=fake.numerify("value##")) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake + ) + other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + other_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake + ) + from extensions.ext_database import db + + app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + assert len(app_variables) == 3 + assert len(other_app_variables) == 1 + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_workflow_variables(app.id) + app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + assert len(app_variables_after) == 0 + assert len(other_app_variables_after) == 1 + + def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting all variables for a specific node successfully. + + This test verifies that the service can delete all variables + associated with a specific node while preserving variables + for other nodes and conversation variables. This is important + for node-specific cleanup operations in workflow management. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + node_id = fake.word() + for i in range(2): + test_value = StringSegment(value=fake.numerify("node_value##")) + self._create_test_variable( + db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake + ) + other_node_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake + ) + conv_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake + ) + from extensions.ext_database import db + + target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + other_node_variables = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + ) + conv_variables = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(target_node_variables) == 2 + assert len(other_node_variables) == 1 + assert len(conv_variables) == 1 + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_node_variables(app.id, node_id) + target_node_variables_after = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + ) + other_node_variables_after = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + ) + conv_variables_after = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(target_node_variables_after) == 0 + assert len(other_node_variables_after) == 1 + assert len(conv_variables_after) == 1 + + def test_prefill_conversation_variable_default_values_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test prefill conversation variable default values successfully. + + This test verifies that the service can automatically create + conversation variables with default values based on the workflow + configuration when none exist. This is important for initializing + workflow variables with proper defaults from the workflow definition. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) + from core.variables.variables import StringVariable + + conv_var1 = StringVariable( + id=fake.uuid4(), + name="conv_var1", + value="default_value1", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"], + ) + conv_var2 = StringVariable( + id=fake.uuid4(), + name="conv_var2", + value="default_value2", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], + ) + workflow.conversation_variables = [conv_var1, conv_var2] + from extensions.ext_database import db + + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + service.prefill_conversation_variable_default_values(workflow) + draft_variables = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(draft_variables) == 2 + var_names = [var.name for var in draft_variables] + assert "conv_var1" in var_names + assert "conv_var2" in var_names + for var in draft_variables: + assert var.app_id == app.id + assert var.node_id == CONVERSATION_VARIABLE_NODE_ID + assert var.editable is True + assert var.get_variable_type() == DraftVariableType.CONVERSATION + + def test_get_conversation_id_from_draft_variable_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting conversation ID from draft variable successfully. + + This test verifies that the service can extract the conversation ID + from a system variable named "conversation_id". This is important + for maintaining conversation context across workflow executions. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + conversation_id = fake.uuid4() + conv_id_value = StringSegment(value=conversation_id) + self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "conversation_id", + conv_id_value, + "system", + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + assert retrieved_conv_id == conversation_id + + def test_get_conversation_id_from_draft_variable_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting conversation ID when it doesn't exist. + + This test verifies that the service returns None when no + conversation_id variable exists for the app. This ensures + proper handling of missing conversation context scenarios. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + assert retrieved_conv_id is None + + def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing system variables successfully. + + This test verifies that the service can filter and return only + system variables, excluding conversation and node variables. + System variables are internal variables used by the workflow + engine for maintaining state and context. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + sys_var1_value = StringSegment(value=fake.word()) + sys_var2_value = StringSegment(value=fake.word()) + sys_var1 = self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake + ) + sys_var2 = self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake + ) + conv_var_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_system_variables(app.id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == SYSTEM_VARIABLE_NODE_ID + assert var.app_id == app.id + assert var.get_variable_type() == DraftVariableType.SYS + var_names = [var.name for var in result.variables] + assert "sys_var1" in var_names + assert "sys_var2" in var_names + assert "conv_var" not in var_names + + def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting variables by name successfully for different types. + + This test verifies that the service can retrieve variables by name + for different variable types (conversation, system, node). This + functionality is important for variable lookup operations during + workflow execution and user interactions. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + conv_var = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake + ) + sys_var = self._create_test_variable( + db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake + ) + node_var = self._create_test_variable( + db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var") + assert retrieved_conv_var is not None + assert retrieved_conv_var.name == "test_conv_var" + assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID + retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var") + assert retrieved_sys_var is not None + assert retrieved_sys_var.name == "test_sys_var" + assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID + retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var") + assert retrieved_node_var is not None + assert retrieved_node_var.name == "test_node_var" + assert retrieved_node_var.node_id == "test_node" + + def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting variables by name when they don't exist. + + This test verifies that the service returns None when trying to + retrieve variables by name that don't exist. This ensures proper + handling of missing variable scenarios for all variable types. 
+ """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var") + assert retrieved_conv_var is None + retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var") + assert retrieved_sys_var is None + retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var") + assert retrieved_node_var is None diff --git a/api/tests/test_containers_integration_tests/workflow/__init__.py b/api/tests/test_containers_integration_tests/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/__init__.py b/api/tests/test_containers_integration_tests/workflow/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/__init__.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_executor.py new file mode 100644 index 0000000000..487178ff58 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -0,0 +1,11 @@ +import pytest + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + +CODE_LANGUAGE = "unsupported_language" + + +def test_unsupported_with_code_template(): + with pytest.raises(CodeExecutionError) as e: + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py new file mode 100644 index 0000000000..19a41b6186 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -0,0 +1,47 @@ +from textwrap import dedent + +from .test_utils import CodeExecutorTestMixin + + +class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): + """Test class for JavaScript code executor functionality.""" + + def test_javascript_plain(self, flask_app_with_containers): + """Test basic JavaScript code execution with console.log output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = 'console.log("Hello World")' + result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) + assert result_message == "Hello World\n" + + def test_javascript_json(self, flask_app_with_containers): + """Test JavaScript code execution with JSON output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = dedent(""" + obj = {'Hello': 'World'} + console.log(JSON.stringify(obj)) + """) + result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) + assert result == '{"Hello":"World"}\n' + + def test_javascript_with_code_template(self, flask_app_with_containers): + """Test JavaScript workflow code template execution with inputs""" + CodeExecutor, CodeLanguage = 
self.code_executor_imports + JavascriptCodeProvider, _ = self.javascript_imports + + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JAVASCRIPT, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + + def test_javascript_get_runner_script(self, flask_app_with_containers): + """Test JavaScript template transformer runner script generation""" + _, NodeJsTemplateTransformer = self.javascript_imports + + runner_script = NodeJsTemplateTransformer.get_runner_script() + assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py new file mode 100644 index 0000000000..c764801170 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -0,0 +1,42 @@ +import base64 + +from .test_utils import CodeExecutorTestMixin + + +class TestJinja2CodeExecutor(CodeExecutorTestMixin): + """Test class for Jinja2 code executor functionality.""" + + def test_jinja2(self, flask_app_with_containers): + """Test basic Jinja2 template execution with variable substitution""" + CodeExecutor, CodeLanguage = self.code_executor_imports + _, Jinja2TemplateTransformer = self.jinja2_imports + + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CodeLanguage.JINJA2, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" + + def test_jinja2_with_code_template(self, flask_app_with_containers): + """Test Jinja2 workflow code template execution with inputs""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} + + def test_jinja2_get_runner_script(self, flask_app_with_containers): + """Test Jinja2 template transformer runner script generation""" + _, Jinja2TemplateTransformer = self.jinja2_imports + + runner_script = Jinja2TemplateTransformer.get_runner_script() + assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py new file mode 100644 index 0000000000..6d93df2472 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -0,0 +1,47 @@ +from textwrap import dedent + +from .test_utils import CodeExecutorTestMixin + + +class TestPython3CodeExecutor(CodeExecutorTestMixin): + """Test class for Python3 code executor 
functionality.""" + + def test_python3_plain(self, flask_app_with_containers): + """Test basic Python3 code execution with print output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = 'print("Hello World")' + result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) + assert result == "Hello World\n" + + def test_python3_json(self, flask_app_with_containers): + """Test Python3 code execution with JSON output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = dedent(""" + import json + print(json.dumps({'Hello': 'World'})) + """) + result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) + assert result == '{"Hello": "World"}\n' + + def test_python3_with_code_template(self, flask_app_with_containers): + """Test Python3 workflow code template execution with inputs""" + CodeExecutor, CodeLanguage = self.code_executor_imports + Python3CodeProvider, _ = self.python3_imports + + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.PYTHON3, + code=Python3CodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + + def test_python3_get_runner_script(self, flask_app_with_containers): + """Test Python3 template transformer runner script generation""" + _, Python3TemplateTransformer = self.python3_imports + + runner_script = Python3TemplateTransformer.get_runner_script() + assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_utils.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_utils.py new file mode 100644 index 0000000000..35a095b049 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_utils.py @@ -0,0 +1,115 @@ +""" +Test utilities for code executor integration tests. + +This module provides lazy import functions to avoid module loading issues +that occur when modules are imported before the flask_app_with_containers fixture +has set up the proper environment variables and configuration. +""" + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + pass + + +def force_reload_code_executor(): + """ + Force reload the code_executor module to reinitialize code_execution_endpoint_url. + + This function should be called after setting up environment variables + to ensure the code_execution_endpoint_url is initialized with the correct value. + """ + try: + import core.helper.code_executor.code_executor + + importlib.reload(core.helper.code_executor.code_executor) + except Exception as e: + # Log the error but don't fail the test + print(f"Warning: Failed to reload code_executor module: {e}") + + +def get_code_executor_imports(): + """ + Lazy import function for core CodeExecutor classes. + + Returns: + tuple: (CodeExecutor, CodeLanguage) classes + """ + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage + + return CodeExecutor, CodeLanguage + + +def get_javascript_imports(): + """ + Lazy import function for JavaScript-specific modules. 
+ + Returns: + tuple: (JavascriptCodeProvider, NodeJsTemplateTransformer) classes + """ + from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider + from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer + + return JavascriptCodeProvider, NodeJsTemplateTransformer + + +def get_python3_imports(): + """ + Lazy import function for Python3-specific modules. + + Returns: + tuple: (Python3CodeProvider, Python3TemplateTransformer) classes + """ + from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider + from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer + + return Python3CodeProvider, Python3TemplateTransformer + + +def get_jinja2_imports(): + """ + Lazy import function for Jinja2-specific modules. + + Returns: + tuple: (None, Jinja2TemplateTransformer) classes + """ + from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer + + return None, Jinja2TemplateTransformer + + +class CodeExecutorTestMixin: + """ + Mixin class providing lazy import methods for code executor tests. + + This mixin helps avoid module loading issues by deferring imports + until after the flask_app_with_containers fixture has set up the environment. + """ + + def setup_method(self): + """ + Setup method called before each test method. + Force reload the code_executor module to ensure fresh initialization. + """ + force_reload_code_executor() + + @property + def code_executor_imports(self): + """Property to get CodeExecutor and CodeLanguage classes.""" + return get_code_executor_imports() + + @property + def javascript_imports(self): + """Property to get JavaScript-specific classes.""" + return get_javascript_imports() + + @property + def python3_imports(self): + """Property to get Python3-specific classes.""" + return get_python3_imports() + + @property + def jinja2_imports(self): + """Property to get Jinja2-specific classes.""" + return get_jinja2_imports() diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index e9d4ee1935..0ae6a09f5b 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -1,5 +1,6 @@ import os +import pytest from flask import Flask from packaging.version import Version from yarl import URL @@ -137,3 +138,61 @@ def test_db_extras_options_merging(monkeypatch): options = engine_options["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options + + +@pytest.mark.parametrize( + ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), + [ + ("redis://localhost:6379/1", "localhost", 6379, None, None, "1"), + ("redis://:password@localhost:6379/1", "localhost", 6379, None, "password", "1"), + ("redis://:mypass%23123@localhost:6379/1", "localhost", 6379, None, "mypass#123", "1"), + ("redis://user:pass%40word@redis-host:6380/2", "redis-host", 6380, "user", "pass@word", "2"), + ("redis://admin:complex%23pass%40word@127.0.0.1:6379/0", "127.0.0.1", 6379, "admin", "complex#pass@word", "0"), + ( + "redis://user%40domain:secret%23123@redis.example.com:6380/3", + "redis.example.com", + 6380, + "user@domain", + "secret#123", + "3", + ), + # Password containing %23 substring (double encoding scenario) + ("redis://:mypass%2523@localhost:6379/1", "localhost", 6379, None, "mypass%23", "1"), + # Username and password both 
containing encoded characters + ("redis://user%2525%40:pass%2523@localhost:6379/1", "localhost", 6379, "user%25@", "pass%23", "1"), + ], +) +def test_celery_broker_url_with_special_chars_password( + monkeypatch, broker_url, expected_host, expected_port, expected_username, expected_password, expected_db +): + """Test that CELERY_BROKER_URL with various formats are handled correctly.""" + from kombu.utils.url import parse_url + + # clear system environment variables + os.environ.clear() + + # Set up basic required environment variables (following existing pattern) + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + # Set the CELERY_BROKER_URL to test + monkeypatch.setenv("CELERY_BROKER_URL", broker_url) + + # Create config and verify the URL is stored correctly + config = DifyConfig() + assert broker_url == config.CELERY_BROKER_URL + + # Test actual parsing behavior using kombu's parse_url (same as production) + redis_config = parse_url(config.CELERY_BROKER_URL) + + # Verify the parsing results match expectations (using kombu's field names) + assert redis_config["hostname"] == expected_host + assert redis_config["port"] == expected_port + assert redis_config["userid"] == expected_username # kombu uses 'userid' not 'username' + assert redis_config["password"] == expected_password + assert redis_config["virtual_host"] == expected_db # kombu uses 'virtual_host' not 'db' diff --git a/api/tests/unit_tests/controllers/console/app/test_description_validation.py b/api/tests/unit_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..178267e560 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,252 @@ +import pytest + +from controllers.console.app.app import _validate_description_length as app_validate +from controllers.console.datasets.datasets import _validate_description_length as dataset_validate +from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + +class TestDescriptionValidationUnit: + """Unit tests for description validation functions in App and Dataset APIs""" + + def test_app_validate_description_length_valid(self): + """Test App validation function with valid descriptions""" + # Empty string should be valid + assert app_validate("") == "" + + # None should be valid + assert app_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert app_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert app_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert app_validate(just_under) == just_under + + def test_app_validate_description_length_invalid(self): + """Test App validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + app_validate(just_over) + assert "Description cannot exceed 400 characters." 
in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + app_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 1000 characters should fail + very_long = "x" * 1000 + with pytest.raises(ValueError) as exc_info: + app_validate(very_long) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_dataset_validate_description_length_valid(self): + """Test Dataset validation function with valid descriptions""" + # Empty string should be valid + assert dataset_validate("") == "" + + # Short description should be valid + short_desc = "Short description" + assert dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert dataset_validate(just_under) == just_under + + def test_dataset_validate_description_length_invalid(self): + """Test Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + dataset_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_service_dataset_validate_description_length_valid(self): + """Test Service Dataset validation function with valid descriptions""" + # Empty string should be valid + assert service_dataset_validate("") == "" + + # None should be valid + assert service_dataset_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert service_dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert service_dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert service_dataset_validate(just_under) == just_under + + def test_service_dataset_validate_description_length_invalid(self): + """Test Service Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(way_over) + assert "Description cannot exceed 400 characters." 
in str(exc_info.value) + + def test_app_dataset_validation_consistency(self): + """Test that App and Dataset validation functions behave identically""" + test_cases = [ + "", # Empty string + "Short description", # Normal description + "x" * 100, # Medium description + "x" * 400, # Exactly at limit + ] + + # Test valid cases produce same results + for test_desc in test_cases: + assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc) + + # Test invalid cases produce same errors + invalid_cases = [ + "x" * 401, # Just over limit + "x" * 500, # Way over limit + "x" * 1000, # Very long + ] + + for invalid_desc in invalid_cases: + app_error = None + dataset_error = None + service_dataset_error = None + + # Capture App validation error + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + # Capture Dataset validation error + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + # Capture Service Dataset validation error + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + # All should produce errors + assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters" + assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters" + error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters" + assert service_dataset_error is not None, error_msg + + # Errors should be identical + error_msg = f"Error messages should be identical for {len(invalid_desc)} characters" + assert app_error == dataset_error == service_dataset_error, error_msg + assert app_error == "Description cannot exceed 400 characters." + + def test_boundary_values(self): + """Test boundary values around the 400 character limit""" + boundary_tests = [ + (0, True), # Empty + (1, True), # Minimum + (399, True), # Just under limit + (400, True), # Exactly at limit + (401, False), # Just over limit + (402, False), # Over limit + (500, False), # Way over limit + ] + + for length, should_pass in boundary_tests: + test_desc = "x" * length + + if should_pass: + # Should not raise exception + assert app_validate(test_desc) == test_desc + assert dataset_validate(test_desc) == test_desc + assert service_dataset_validate(test_desc) == test_desc + else: + # Should raise ValueError + with pytest.raises(ValueError): + app_validate(test_desc) + with pytest.raises(ValueError): + dataset_validate(test_desc) + with pytest.raises(ValueError): + service_dataset_validate(test_desc) + + def test_special_characters(self): + """Test validation with special characters, Unicode, etc.""" + # Unicode characters + unicode_desc = "测试描述" * 100 # Chinese characters + if len(unicode_desc) <= 400: + assert app_validate(unicode_desc) == unicode_desc + assert dataset_validate(unicode_desc) == unicode_desc + assert service_dataset_validate(unicode_desc) == unicode_desc + + # Special characters + special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" 
* 10 + if len(special_desc) <= 400: + assert app_validate(special_desc) == special_desc + assert dataset_validate(special_desc) == special_desc + assert service_dataset_validate(special_desc) == special_desc + + # Mixed content + mixed_desc = "Mixed content: 测试 123 !@# " * 15 + if len(mixed_desc) <= 400: + assert app_validate(mixed_desc) == mixed_desc + assert dataset_validate(mixed_desc) == mixed_desc + assert service_dataset_validate(mixed_desc) == mixed_desc + elif len(mixed_desc) > 400: + with pytest.raises(ValueError): + app_validate(mixed_desc) + with pytest.raises(ValueError): + dataset_validate(mixed_desc) + with pytest.raises(ValueError): + service_dataset_validate(mixed_desc) + + def test_whitespace_handling(self): + """Test validation with various whitespace scenarios""" + # Leading/trailing whitespace + whitespace_desc = " Description with whitespace " + if len(whitespace_desc) <= 400: + assert app_validate(whitespace_desc) == whitespace_desc + assert dataset_validate(whitespace_desc) == whitespace_desc + assert service_dataset_validate(whitespace_desc) == whitespace_desc + + # Newlines and tabs + multiline_desc = "Line 1\nLine 2\tTabbed content" + if len(multiline_desc) <= 400: + assert app_validate(multiline_desc) == multiline_desc + assert dataset_validate(multiline_desc) == multiline_desc + assert service_dataset_validate(multiline_desc) == multiline_desc + + # Only whitespace over limit + only_spaces = " " * 401 + with pytest.raises(ValueError): + app_validate(only_spaces) + with pytest.raises(ValueError): + dataset_validate(only_spaces) + with pytest.raises(ValueError): + service_dataset_validate(only_spaces) diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py new file mode 100644 index 0000000000..cb5562d345 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -0,0 +1,278 @@ +import io +from unittest.mock import patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.common.errors import FilenameNotExistsError +from controllers.console.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from services.errors.file import FileTooLargeError as ServiceFileTooLargeError +from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + +class TestFileUploadSecurity: + """Test file upload security logic without complex framework setup""" + + # Test 1: Basic file validation + def test_should_validate_file_presence(self): + """Test that missing file is detected""" + from flask import Flask, request + + app = Flask(__name__) + + with app.test_request_context(method="POST", data={}): + # Simulate the check in FileApi.post() + if "file" not in request.files: + with pytest.raises(NoFileUploadedError): + raise NoFileUploadedError() + + def test_should_validate_multiple_files(self): + """Test that multiple files are rejected""" + from flask import Flask, request + + app = Flask(__name__) + + file_data = { + "file": (io.BytesIO(b"content1"), "file1.txt", "text/plain"), + "file2": (io.BytesIO(b"content2"), "file2.txt", "text/plain"), + } + + with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"): + # Simulate the check in FileApi.post() + if len(request.files) > 1: + with pytest.raises(TooManyFilesError): + raise TooManyFilesError() + + def test_should_validate_empty_filename(self): + """Test that empty 
filename is rejected""" + from flask import Flask, request + + app = Flask(__name__) + + file_data = {"file": (io.BytesIO(b"content"), "", "text/plain")} + + with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"): + file = request.files["file"] + if not file.filename: + with pytest.raises(FilenameNotExistsError): + raise FilenameNotExistsError + + # Test 2: Security - Filename sanitization + def test_should_detect_path_traversal_in_filename(self): + """Test protection against directory traversal attacks""" + dangerous_filenames = [ + "../../../etc/passwd", + "..\\..\\windows\\system32\\config\\sam", + "../../../../etc/shadow", + "./../../../sensitive.txt", + ] + + for filename in dangerous_filenames: + # Any filename containing .. should be considered dangerous + assert ".." in filename, f"Filename {filename} should be detected as path traversal" + + def test_should_detect_null_byte_injection(self): + """Test protection against null byte injection""" + dangerous_filenames = [ + "file.jpg\x00.php", + "document.pdf\x00.exe", + "image.png\x00.sh", + ] + + for filename in dangerous_filenames: + # Null bytes should be detected + assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection" + + def test_should_sanitize_special_characters(self): + """Test that special characters in filenames are handled safely""" + # Characters that could be problematic in various contexts + dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"] + + for char in dangerous_chars: + filename = f"file{char}name.txt" + # These characters should be detected or sanitized + assert any(c in filename for c in dangerous_chars) + + # Test 3: Permission validation + def test_should_validate_dataset_permissions(self): + """Test dataset upload permission logic""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = "datasets" + + # Simulate the permission check in FileApi.post() + if source == "datasets" and not user.is_dataset_editor: + with pytest.raises(Forbidden): + raise Forbidden() + + def test_should_allow_general_upload_without_permission(self): + """Test general upload doesn't require dataset permission""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = None # General upload + + # This should not raise an exception + if source == "datasets" and not user.is_dataset_editor: + raise Forbidden() + # Test passes if no exception is raised + + # Test 4: Service error handling + @patch("services.file_service.FileService.upload_file") + def test_should_handle_file_too_large_error(self, mock_upload): + """Test that service FileTooLargeError is properly converted""" + mock_upload.side_effect = ServiceFileTooLargeError("File too large") + + try: + mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None) + except ServiceFileTooLargeError as e: + # Simulate the error conversion in FileApi.post() + with pytest.raises(FileTooLargeError): + raise FileTooLargeError(e.description) + + @patch("services.file_service.FileService.upload_file") + def test_should_handle_unsupported_file_type_error(self, mock_upload): + """Test that service UnsupportedFileTypeError is properly converted""" + mock_upload.side_effect = ServiceUnsupportedFileTypeError() + + try: + mock_upload( + filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None + ) + except ServiceUnsupportedFileTypeError: + # Simulate the error conversion in 
FileApi.post() + with pytest.raises(UnsupportedFileTypeError): + raise UnsupportedFileTypeError() + + # Test 5: File type security + def test_should_identify_dangerous_file_extensions(self): + """Test detection of potentially dangerous file extensions""" + dangerous_extensions = [ + ".php", + ".PHP", + ".pHp", # PHP files (case variations) + ".exe", + ".EXE", # Executables + ".sh", + ".SH", # Shell scripts + ".bat", + ".BAT", # Batch files + ".cmd", + ".CMD", # Command files + ".ps1", + ".PS1", # PowerShell + ".jar", + ".JAR", # Java archives + ".vbs", + ".VBS", # VBScript + ] + + safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"] + + # Just verify our test data is correct + for ext in dangerous_extensions: + assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + for ext in safe_extensions: + assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + def test_should_detect_double_extensions(self): + """Test detection of double extension attacks""" + suspicious_filenames = [ + "image.jpg.php", + "document.pdf.exe", + "photo.png.sh", + "file.txt.bat", + ] + + for filename in suspicious_filenames: + # Check that these have multiple extensions + parts = filename.split(".") + assert len(parts) > 2, f"Filename {filename} should have multiple extensions" + + # Test 6: Configuration validation + def test_upload_configuration_structure(self): + """Test that upload configuration has correct structure""" + # Simulate the configuration returned by FileApi.get() + config = { + "file_size_limit": 15, + "batch_count_limit": 5, + "image_file_size_limit": 10, + "video_file_size_limit": 500, + "audio_file_size_limit": 50, + "workflow_file_upload_limit": 10, + } + + # Verify all required fields are present + required_fields = [ + "file_size_limit", + "batch_count_limit", + "image_file_size_limit", + "video_file_size_limit", + "audio_file_size_limit", + "workflow_file_upload_limit", + ] + + for field in required_fields: + assert field in config, f"Missing required field: {field}" + assert isinstance(config[field], int), f"Field {field} should be an integer" + assert config[field] > 0, f"Field {field} should be positive" + + # Test 7: Source parameter handling + def test_source_parameter_normalization(self): + """Test that source parameter is properly normalized""" + test_cases = [ + ("datasets", "datasets"), + ("other", None), + ("", None), + (None, None), + ] + + for input_source, expected in test_cases: + # Simulate the source normalization in FileApi.post() + source = "datasets" if input_source == "datasets" else None + if source not in ("datasets", None): + source = None + assert source == expected + + # Test 8: Boundary conditions + def test_should_handle_edge_case_file_sizes(self): + """Test handling of boundary file sizes""" + test_cases = [ + (0, "Empty file"), # 0 bytes + (1, "Single byte"), # 1 byte + (15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB + (15 * 1024 * 1024, "At limit"), # Exactly 15MB + (15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB + ] + + for size, description in test_cases: + # Just verify our test data + assert isinstance(size, int), f"{description}: Size should be integer" + assert size >= 0, f"{description}: Size should be non-negative" + + def test_should_handle_special_mime_types(self): + """Test handling of various MIME types""" + mime_type_tests = [ + ("application/octet-stream", "Generic binary"), + ("text/plain", "Plain text"), + ("image/jpeg", "JPEG image"), + 
("application/pdf", "PDF document"), + ("", "Empty MIME type"), + (None, "None MIME type"), + ] + + for mime_type, description in mime_type_tests: + # Verify test data structure + if mime_type is not None: + assert isinstance(mime_type, str), f"{description}: MIME type should be string or None" diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py new file mode 100644 index 0000000000..5c484403a6 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -0,0 +1,336 @@ +""" +Unit tests for Service API File Preview endpoint +""" + +import uuid +from unittest.mock import Mock, patch + +import pytest + +from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError +from controllers.service_api.app.file_preview import FilePreviewApi +from models.model import App, EndUser, Message, MessageFile, UploadFile + + +class TestFilePreviewApi: + """Test suite for FilePreviewApi""" + + @pytest.fixture + def file_preview_api(self): + """Create FilePreviewApi instance for testing""" + return FilePreviewApi() + + @pytest.fixture + def mock_app(self): + """Mock App model""" + app = Mock(spec=App) + app.id = str(uuid.uuid4()) + app.tenant_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_end_user(self): + """Mock EndUser model""" + end_user = Mock(spec=EndUser) + end_user.id = str(uuid.uuid4()) + return end_user + + @pytest.fixture + def mock_upload_file(self): + """Mock UploadFile model""" + upload_file = Mock(spec=UploadFile) + upload_file.id = str(uuid.uuid4()) + upload_file.name = "test_file.jpg" + upload_file.mime_type = "image/jpeg" + upload_file.size = 1024 + upload_file.key = "storage/key/test_file.jpg" + upload_file.tenant_id = str(uuid.uuid4()) + return upload_file + + @pytest.fixture + def mock_message_file(self): + """Mock MessageFile model""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.upload_file_id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + return message_file + + @pytest.fixture + def mock_message(self): + """Mock Message model""" + message = Mock(spec=Message) + message.id = str(uuid.uuid4()) + message.app_id = str(uuid.uuid4()) + return message + + def test_validate_file_ownership_success( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test successful file ownership validation""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up the mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute the method + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + + # Assertions + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + def test_validate_file_ownership_file_not_found(self, file_preview_api): + """Test file ownership validation when MessageFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) 
+ + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "File not found in message context" in str(exc_info.value) + + def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file): + """Test file ownership validation when Message not owned by app""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile found but Message not owned by app + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + None, # Message query - not found (access denied) + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "not owned by requesting app" in str(exc_info.value) + + def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message): + """Test file ownership validation when UploadFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile and Message found but UploadFile not found + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + mock_message, # Message query - found + None, # UploadFile query - not found + ] + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "Upload file record not found" in str(exc_info.value) + + def test_validate_file_ownership_tenant_mismatch( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test file ownership validation with tenant mismatch""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up tenant mismatch + mock_upload_file.tenant_id = "different_tenant_id" + mock_app.tenant_id = "app_tenant_id" + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "tenant mismatch" in str(exc_info.value) + + def test_validate_file_ownership_invalid_input(self, file_preview_api): + """Test file ownership validation with invalid input""" + + # Test with empty file_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("", "app_id") + assert "Invalid file or app identifier" in str(exc_info.value) + + # Test with empty app_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("file_id", "") + assert "Invalid file or app identifier" in 
str(exc_info.value) + + def test_build_file_response_basic(self, file_preview_api, mock_upload_file): + """Test basic file response building""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check response properties + assert response.mimetype == mock_upload_file.mime_type + assert response.direct_passthrough is True + assert response.headers["Content-Length"] == str(mock_upload_file.size) + assert "Cache-Control" in response.headers + + def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file): + """Test file response building with attachment flag""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True) + + # Check attachment-specific headers + assert "attachment" in response.headers["Content-Disposition"] + assert mock_upload_file.name in response.headers["Content-Disposition"] + assert response.headers["Content-Type"] == "application/octet-stream" + + def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file): + """Test file response building for audio/video files""" + mock_generator = Mock() + mock_upload_file.mime_type = "video/mp4" + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check Range support for media files + assert response.headers["Accept-Ranges"] == "bytes" + + def test_build_file_response_no_size(self, file_preview_api, mock_upload_file): + """Test file response building when size is unknown""" + mock_generator = Mock() + mock_upload_file.size = 0 # Unknown size + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Content-Length should not be set when size is unknown + assert "Content-Length" not in response.headers + + @patch("controllers.service_api.app.file_preview.storage") + def test_get_method_integration( + self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message + ): + """Test the full GET method integration (without decorator)""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + mock_generator = Mock() + mock_storage.load.return_value = mock_generator + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: + # Mock request parsing + mock_parser = Mock() + mock_parser.parse_args.return_value = {"as_attachment": False} + mock_reqparse.RequestParser.return_value = mock_parser + + # Test the core logic directly without Flask decorators + # Validate file ownership + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test file response building + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + assert response is not None + + # Verify storage was not called here + 
mock_storage.load.assert_not_called() # Since we're testing components separately + + @patch("controllers.service_api.app.file_preview.storage") + def test_storage_error_handling( + self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test storage error handling in the core logic""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + # Mock storage error + mock_storage.load.side_effect = Exception("Storage error") + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries for validation + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # First validate file ownership works + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test storage error handling + with pytest.raises(Exception) as exc_info: + mock_storage.load(mock_upload_file.key, stream=True) + + assert "Storage error" in str(exc_info.value) + + @patch("controllers.service_api.app.file_preview.logger") + def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api): + """Test that unexpected errors are logged properly""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database query to raise unexpected exception + mock_db.session.query.side_effect = Exception("Unexpected database error") + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + # Verify error message + assert "File access validation failed" in str(exc_info.value) + + # Verify logging was called + mock_logger.exception.assert_called_once_with( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"}, + ) diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 209f8b7c57..1dc380ad0b 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -102,9 +102,14 @@ class TestPhoenixConfig: assert config.project == "default" def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1") - assert config.endpoint == "https://custom.phoenix.com" + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" class TestLangfuseConfig: @@ -368,13 +373,15 @@ class TestConfigIntegration: """Test that URL normalization works 
consistently across configs""" # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") - phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/") + phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") aliyun_config = AliyunConfig( license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" ) assert arize_config.endpoint == "https://arize.com" - assert phoenix_config.endpoint == "https://phoenix.com" + assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com" assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" def test_project_default_values(self): diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 64d0d8c7e7..b33a83ba77 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,4 +1,4 @@ -from core.variables.types import SegmentType +from core.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: @@ -17,7 +17,6 @@ class TestSegmentTypeIsArrayType: value is tested for the is_array_type method. """ # Arrange - all_segment_types = set(SegmentType) expected_array_types = [ SegmentType.ARRAY_ANY, SegmentType.ARRAY_STRING, @@ -58,3 +57,27 @@ class TestSegmentTypeIsArrayType: for seg_type in enum_values: is_array = seg_type.is_array_type() assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}" + + +class TestSegmentTypeIsValidArrayValidation: + """ + Test SegmentType.is_valid with array types using different validation strategies. 
+ """ + + def test_array_validation_all_success(self): + value = ["hello", "world", "foo"] + assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + + def test_array_validation_all_fail(self): + value = ["hello", 123, "world"] + # Should return False, since 123 is not a string + assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + + def test_array_validation_first(self): + value = ["hello", 123, None] + assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + + def test_array_validation_none(self): + value = [1, 2, 3] + # validation is None, skip + assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index bb6d72f51e..3101f7dd34 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -49,7 +49,7 @@ def test_executor_with_json_body_and_number_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"number": 42} assert executor.data is None assert executor.files is None @@ -102,7 +102,7 @@ def test_executor_with_json_body_and_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} assert executor.data is None assert executor.files is None @@ -157,7 +157,7 @@ def test_executor_with_json_body_and_nested_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} assert executor.data is None assert executor.files is None @@ -245,7 +245,7 @@ def test_executor_with_form_data(): assert executor.url == "https://api.example.com/upload" assert "Content-Type" in executor.headers assert "multipart/form-data" in executor.headers["Content-Type"] - assert executor.params == [] + assert executor.params is None assert executor.json is None # '__multipart_placeholder__' is expected when no file inputs exist, # to ensure the request is treated as multipart/form-data by the backend. 
diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py new file mode 100644 index 0000000000..dd2bc21814 --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -0,0 +1,168 @@ +import datetime +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.orm import Session + +from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs + + +class TestClearFreePlanTenantExpiredLogs: + """Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method.""" + + @pytest.fixture + def mock_session(self): + """Create a mock database session.""" + session = Mock(spec=Session) + session.query.return_value.filter.return_value.all.return_value = [] + session.query.return_value.filter.return_value.delete.return_value = 0 + return session + + @pytest.fixture + def mock_storage(self): + """Create a mock storage object.""" + storage = Mock() + storage.save.return_value = None + return storage + + @pytest.fixture + def sample_message_ids(self): + """Sample message IDs for testing.""" + return ["msg-1", "msg-2", "msg-3"] + + @pytest.fixture + def sample_records(self): + """Sample records for testing.""" + records = [] + for i in range(3): + record = Mock() + record.id = f"record-{i}" + record.to_dict.return_value = { + "id": f"record-{i}", + "message_id": f"msg-{i}", + "created_at": datetime.datetime.now().isoformat(), + } + records.append(record) + return records + + def test_clear_message_related_tables_empty_message_ids(self, mock_session): + """Test that method returns early when message_ids is empty.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", []) + + # Should not call any database operations + mock_session.query.assert_not_called() + mock_storage.save.assert_not_called() + + def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids): + """Test when no related records are found.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = [] + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call query for each related table but find no records + assert mock_session.query.call_count > 0 + mock_storage.save.assert_not_called() + + def test_clear_message_related_tables_with_records_and_to_dict( + self, mock_session, sample_message_ids, sample_records + ): + """Test when records are found and have to_dict method.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call to_dict on each record (called once per table, so 7 times total) + for record in sample_records: + assert record.to_dict.call_count == 7 + + # Should save backup data + assert mock_storage.save.call_count > 0 + + def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids): + """Test when records are found but don't have to_dict method.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as 
mock_storage: + # Create records without to_dict method + records = [] + for i in range(2): + record = Mock() + mock_table = Mock() + mock_id_column = Mock() + mock_id_column.name = "id" + mock_message_id_column = Mock() + mock_message_id_column.name = "message_id" + mock_table.columns = [mock_id_column, mock_message_id_column] + record.__table__ = mock_table + record.id = f"record-{i}" + record.message_id = f"msg-{i}" + del record.to_dict + records.append(record) + + # Mock records for first table only, empty for others + mock_session.query.return_value.filter.return_value.all.side_effect = [ + records, + [], + [], + [], + [], + [], + [], + ] + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should save backup data even without to_dict + assert mock_storage.save.call_count > 0 + + def test_clear_message_related_tables_storage_error_continues( + self, mock_session, sample_message_ids, sample_records + ): + """Test that method continues even when storage.save fails.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_storage.save.side_effect = Exception("Storage error") + + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + # Should not raise exception + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should still delete records even if backup fails + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids): + """Test that method continues even when record serialization fails.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + record = Mock() + record.id = "record-1" + record.to_dict.side_effect = Exception("Serialization error") + + mock_session.query.return_value.filter.return_value.all.return_value = [record] + + # Should not raise exception + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should still delete records even if serialization fails + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records): + """Test that deletion is called for found records.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call delete for each table that has records + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_logging_output( + self, mock_session, sample_message_ids, sample_records, capsys + ): + """Test that logging output is generated.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + pass diff --git a/api/uv.lock b/api/uv.lock index 0bce38812e..16624dc8fd 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = 
">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", @@ -983,6 +983,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" }, ] +[[package]] +name = "clickzetta-connector-python" +version = "0.8.102" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "future" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "sqlalchemy" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" }, +] + [[package]] name = "cloudscraper" version = "1.2.71" @@ -1265,6 +1284,8 @@ dependencies = [ { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-celery" }, { name = "opentelemetry-instrumentation-flask" }, + { name = "opentelemetry-instrumentation-redis" }, + { name = "opentelemetry-instrumentation-requests" }, { name = "opentelemetry-instrumentation-sqlalchemy" }, { name = "opentelemetry-propagator-b3" }, { name = "opentelemetry-proto" }, @@ -1318,6 +1339,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "scipy-stubs" }, + { name = "testcontainers" }, { name = "types-aiofiles" }, { name = "types-beautifulsoup4" }, { name = "types-cachetools" }, @@ -1380,6 +1402,7 @@ vdb = [ { name = "alibabacloud-tea-openapi" }, { name = "chromadb" }, { name = "clickhouse-connect" }, + { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, { name = "mo-vector" }, @@ -1447,6 +1470,8 @@ requires-dist = [ { name = "opentelemetry-instrumentation", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" }, { name = "opentelemetry-propagator-b3", specifier = "==1.27.0" }, { name = "opentelemetry-proto", specifier = "==1.27.0" }, @@ -1500,6 +1525,7 @@ dev = [ { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "ruff", specifier = "~=0.12.3" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, + { name = "testcontainers", specifier = "~=4.10.0" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, { name = "types-cachetools", specifier = "~=5.5.0" }, @@ -1562,6 +1588,7 @@ vdb = [ { name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" }, { name = "chromadb", specifier = "==0.5.20" }, { name = "clickhouse-connect", specifier = "~=0.7.16" }, + { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, { name = "mo-vector", specifier = 
"~=0.1.13" }, @@ -1600,6 +1627,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docker" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, +] + [[package]] name = "docstring-parser" version = "0.16" @@ -2091,7 +2132,7 @@ wheels = [ [[package]] name = "google-cloud-bigquery" -version = "3.34.0" +version = "3.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core", extra = ["grpc"] }, @@ -2102,9 +2143,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/24/f9/e9da2d56d7028f05c0e2f5edf6ce43c773220c3172666c3dd925791d763d/google_cloud_bigquery-3.34.0.tar.gz", hash = "sha256:5ee1a78ba5c2ccb9f9a8b2bf3ed76b378ea68f49b6cac0544dc55cc97ff7c1ce", size = 489091, upload-time = "2025-05-29T17:18:06.03Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/2f/3dda76b3ec029578838b1fe6396e6b86eb574200352240e23dea49265bb7/google_cloud_bigquery-3.30.0.tar.gz", hash = "sha256:7e27fbafc8ed33cc200fe05af12ecd74d279fe3da6692585a3cef7aee90575b6", size = 474389, upload-time = "2025-02-27T18:49:45.416Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/7e/7115c4f67ca0bc678f25bff1eab56cc37d06eb9a3978940b2ebd0705aa0a/google_cloud_bigquery-3.34.0-py3-none-any.whl", hash = "sha256:de20ded0680f8136d92ff5256270b5920dfe4fae479f5d0f73e90e5df30b1cf7", size = 253555, upload-time = "2025-05-29T17:18:02.904Z" }, + { url = "https://files.pythonhosted.org/packages/0c/6d/856a6ca55c1d9d99129786c929a27dd9d31992628ebbff7f5d333352981f/google_cloud_bigquery-3.30.0-py2.py3-none-any.whl", hash = "sha256:f4d28d846a727f20569c9b2d2f4fa703242daadcb2ec4240905aa485ba461877", size = 247885, upload-time = "2025-02-27T18:49:43.454Z" }, ] [[package]] @@ -3654,6 +3695,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" }, ] +[[package]] +name = "opentelemetry-instrumentation-redis" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "wrapt" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" }, +] + +[[package]] +name = "opentelemetry-instrumentation-requests" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" }, +] + [[package]] name = "opentelemetry-instrumentation-sqlalchemy" version = "0.48b0" @@ -3868,11 +3939,11 @@ wheels = [ [[package]] name = "packaging" -version = "24.2" +version = "23.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" }, + { url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" }, ] [[package]] @@ -4252,6 +4323,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, ] +[[package]] +name = "pyarrow" +version = "14.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" }, + { url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" }, + { url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" }, + { url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" }, + { url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" }, + { url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" }, + { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" }, + { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" }, + { url = 
"https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" }, + { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" }, + { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" }, + { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" }, +] + [[package]] name = "pyasn1" version = "0.6.1" @@ -5468,6 +5564,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, ] +[[package]] +name = "testcontainers" +version = "4.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docker" }, + { name = "python-dotenv" }, + { name = "typing-extensions" }, + { name = "urllib3" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" }, +] + [[package]] name = "tidb-vector" version = "0.0.9" diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh index 30898b4fcf..9123b2f8ad 100755 --- a/dev/pytest/pytest_all_tests.sh +++ b/dev/pytest/pytest_all_tests.sh @@ -15,3 +15,6 @@ dev/pytest/pytest_workflow.sh # Unit tests dev/pytest/pytest_unit_tests.sh + +# TestContainers tests +dev/pytest/pytest_testcontainers.sh diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh new file mode 100755 index 0000000000..e55a436138 --- /dev/null +++ b/dev/pytest/pytest_testcontainers.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -x + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +cd "$SCRIPT_DIR/../.." + +pytest api/tests/test_containers_integration_tests diff --git a/docker/.env.example b/docker/.env.example index 7ecdf899fe..1b1e9cad7b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -333,6 +333,25 @@ OPENDAL_SCHEME=fs # Configurations for OpenDAL Local File System. 
OPENDAL_FS_ROOT=storage +# ClickZetta Volume Configuration (for storage backend) +# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume +# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters + +# Volume type selection (three types available): +# - user: Personal/small team use, simple config, user-level permissions +# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions +# - external: Data lake integration, external storage connection, volume-level + storage-level permissions +CLICKZETTA_VOLUME_TYPE=user + +# External Volume name (required only when TYPE=external) +CLICKZETTA_VOLUME_NAME= + +# Table Volume table prefix (used only when TYPE=table) +CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_ + +# Dify file directory prefix (isolates from other apps, recommended to keep default) +CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km + # S3 Configuration # S3_ENDPOINT= @@ -416,7 +435,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`. 
VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -653,6 +672,21 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com TABLESTORE_INSTANCE_NAME=instance-name TABLESTORE_ACCESS_KEY_ID=xxx TABLESTORE_ACCESS_KEY_SECRET=xxx +TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false + +# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta` +CLICKZETTA_USERNAME= +CLICKZETTA_PASSWORD= +CLICKZETTA_INSTANCE= +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=quick_start +CLICKZETTA_VCLUSTER=default_ap +CLICKZETTA_SCHEMA=dify +CLICKZETTA_BATCH_SIZE=100 +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese +CLICKZETTA_ANALYZER_MODE=smart +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # ------------------------------ # Knowledge Configuration diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index fe8e4602b7..b5ae4a425c 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -538,7 +538,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.5.0-beta + image: milvusdb/milvus:v2.5.15 profiles: - milvus command: [ 'milvus', 'run', 'standalone' ] diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index ae83aa758d..8e2d40883d 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -93,6 +93,10 @@ x-shared-env: &shared-api-worker-env STORAGE_TYPE: ${STORAGE_TYPE:-opendal} OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} + CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user} + CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-} + CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_} + CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km} S3_ENDPOINT: ${S3_ENDPOINT:-} S3_REGION: ${S3_REGION:-us-east-1} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} @@ -312,6 +316,19 @@ x-shared-env: &shared-api-worker-env TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name} TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx} TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx} + TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false} + CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-} + CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-} + CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-} + CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com} + CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start} + CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap} + CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify} + CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100} + CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true} + CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} + CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} @@ -1086,7 +1103,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.5.0-beta + image: milvusdb/milvus:v2.5.15 profiles: - milvus command: [ 'milvus', 'run', 'standalone' ] diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index 
173aa96118..b4c4f1540d 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -49,9 +49,9 @@ describe('check-i18n script functionality', () => { } vm.runInNewContext(transpile(content), context) - const translationObj = moduleExports.default || moduleExports + const translationObj = (context.module.exports as any).default || context.module.exports - if(!translationObj || typeof translationObj !== 'object') + if (!translationObj || typeof translationObj !== 'object') throw new Error(`Error parsing file: ${filePath}`) const nestedKeys: string[] = [] @@ -62,7 +62,7 @@ describe('check-i18n script functionality', () => { // This is an object (but not array), recurse into it but don't add it as a key iterateKeys(obj[key], nestedKey) } - else { + else { // This is a leaf node (string, number, boolean, array, etc.), add it as a key nestedKeys.push(nestedKey) } @@ -73,7 +73,7 @@ describe('check-i18n script functionality', () => { const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) allKeys.push(...fileKeys) } - catch (error) { + catch (error) { reject(error) } }) @@ -265,16 +265,12 @@ export default translation fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content) const allEnKeys = await getKeysFromLanguage('en-US') - const allZhKeys = await getKeysFromLanguage('zh-Hans') // Test file filtering logic const targetFile = 'components' const filteredEnKeys = allEnKeys.filter(key => key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), ) - const filteredZhKeys = allZhKeys.filter(key => - key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), - ) expect(allEnKeys).toHaveLength(4) // 2 keys from each file expect(filteredEnKeys).toHaveLength(2) // only components keys @@ -566,4 +562,201 @@ export default translation expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys }) }) + + describe('Auto-remove multiline key-value pairs', () => { + // Helper function to simulate removeExtraKeysFromFile logic + function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string { + const lines = content.split('\n') + const linesToRemove: number[] = [] + + for (const keyToRemove of keysToRemove) { + let targetLineIndex = -1 + const linesToRemoveForKey: number[] = [] + + // Find the key line (simplified for single-level keys in test) + for (let i = 0; i < lines.length; i++) { + const line = lines[i] + const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`) + if (keyPattern.test(line)) { + targetLineIndex = i + break + } + } + + if (targetLineIndex !== -1) { + linesToRemoveForKey.push(targetLineIndex) + + // Check if this is a multiline key-value pair + const keyLine = lines[targetLineIndex] + const trimmedKeyLine = keyLine.trim() + + // If key line ends with ":" (not complete value), it's likely multiline + if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !trimmedKeyLine.match(/:\s*['"`]/)) { + // Find the value lines that belong to this key + let currentLine = targetLineIndex + 1 + let foundValue = false + + while (currentLine < lines.length) { + const line = lines[currentLine] + const trimmed = line.trim() + + // Skip empty lines + if (trimmed === '') { + currentLine++ + continue + } + + // Check if this line starts a new key (indicates end of current value) + if (trimmed.match(/^\w+\s*:/)) + break + + // Check if this line is part of the value + if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) { + 
linesToRemoveForKey.push(currentLine) + foundValue = true + + // Check if this line ends the value (ends with quote and comma/no comma) + if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,') + || trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`')) + && !trimmed.startsWith('//')) + break + } + else { + break + } + + currentLine++ + } + } + + linesToRemove.push(...linesToRemoveForKey) + } + } + + // Remove duplicates and sort in reverse order + const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a) + + for (const lineIndex of uniqueLinesToRemove) + lines.splice(lineIndex, 1) + + return lines.join('\n') + } + + it('should remove single-line key-value pairs correctly', () => { + const content = `const translation = { + keepThis: 'This should stay', + removeThis: 'This should be removed', + alsoKeep: 'This should also stay', +} + +export default translation` + + const result = removeExtraKeysFromFile(content, ['removeThis']) + + expect(result).toContain('keepThis: \'This should stay\'') + expect(result).toContain('alsoKeep: \'This should also stay\'') + expect(result).not.toContain('removeThis: \'This should be removed\'') + }) + + it('should remove multiline key-value pairs completely', () => { + const content = `const translation = { + keepThis: 'This should stay', + removeMultiline: + 'This is a multiline value that should be removed completely', + alsoKeep: 'This should also stay', +} + +export default translation` + + const result = removeExtraKeysFromFile(content, ['removeMultiline']) + + expect(result).toContain('keepThis: \'This should stay\'') + expect(result).toContain('alsoKeep: \'This should also stay\'') + expect(result).not.toContain('removeMultiline:') + expect(result).not.toContain('This is a multiline value that should be removed completely') + }) + + it('should handle mixed single-line and multiline removals', () => { + const content = `const translation = { + keepThis: 'Keep this', + removeSingle: 'Remove this single line', + removeMultiline: + 'Remove this multiline value', + anotherMultiline: + 'Another multiline that spans multiple lines', + keepAnother: 'Keep this too', +} + +export default translation` + + const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline']) + + expect(result).toContain('keepThis: \'Keep this\'') + expect(result).toContain('keepAnother: \'Keep this too\'') + expect(result).not.toContain('removeSingle:') + expect(result).not.toContain('removeMultiline:') + expect(result).not.toContain('anotherMultiline:') + expect(result).not.toContain('Remove this single line') + expect(result).not.toContain('Remove this multiline value') + expect(result).not.toContain('Another multiline that spans multiple lines') + }) + + it('should properly detect multiline vs single-line patterns', () => { + const multilineContent = `const translation = { + singleLine: 'This is single line', + multilineKey: + 'This is multiline', + keyWithColon: 'Value with: colon inside', + objectKey: { + nested: 'value' + }, +} + +export default translation` + + // Test that single line with colon in value is not treated as multiline + const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon']) + expect(result1).not.toContain('keyWithColon:') + expect(result1).not.toContain('Value with: colon inside') + + // Test that true multiline is handled correctly + const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey']) + 
expect(result2).not.toContain('multilineKey:') + expect(result2).not.toContain('This is multiline') + + // Test that object key removal works (note: this is a simplified test) + // In real scenario, object removal would be more complex + const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey']) + expect(result3).not.toContain('objectKey: {') + // Note: Our simplified test function doesn't handle nested object removal perfectly + // This is acceptable as it's testing the main multiline string removal functionality + }) + + it('should handle real-world Polish translation structure', () => { + const polishContent = `const translation = { + createApp: 'UTWÓRZ APLIKACJĘ', + newApp: { + captionAppType: 'Jaki typ aplikacji chcesz stworzyć?', + chatbotDescription: + 'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.', + agentDescription: + 'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.', + basic: 'Podstawowy', + }, +} + +export default translation` + + const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription']) + + expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'') + expect(result).toContain('basic: \'Podstawowy\'') + expect(result).not.toContain('captionAppType:') + expect(result).not.toContain('chatbotDescription:') + expect(result).not.toContain('agentDescription:') + expect(result).not.toContain('Jaki typ aplikacji') + expect(result).not.toContain('Zbuduj aplikację opartą na czacie') + expect(result).not.toContain('Zbuduj inteligentnego agenta') + }) + }) }) diff --git a/web/__tests__/description-validation.test.tsx b/web/__tests__/description-validation.test.tsx new file mode 100644 index 0000000000..85263b035f --- /dev/null +++ b/web/__tests__/description-validation.test.tsx @@ -0,0 +1,97 @@ +/** + * Description Validation Test + * + * Tests for the 400-character description validation across App and Dataset + * creation and editing workflows to ensure consistent validation behavior. 
+ */ + +describe('Description Validation Logic', () => { + // Simulate backend validation function + const validateDescriptionLength = (description?: string | null) => { + if (description && description.length > 400) + throw new Error('Description cannot exceed 400 characters.') + + return description + } + + describe('Backend Validation Function', () => { + test('allows description within 400 characters', () => { + const validDescription = 'x'.repeat(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + expect(validateDescriptionLength(validDescription)).toBe(validDescription) + }) + + test('allows empty description', () => { + expect(() => validateDescriptionLength('')).not.toThrow() + expect(() => validateDescriptionLength(null)).not.toThrow() + expect(() => validateDescriptionLength(undefined)).not.toThrow() + }) + + test('rejects description exceeding 400 characters', () => { + const invalidDescription = 'x'.repeat(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + }) + }) + + describe('Backend Validation Consistency', () => { + test('App and Dataset have consistent validation limits', () => { + const maxLength = 400 + const validDescription = 'x'.repeat(maxLength) + const invalidDescription = 'x'.repeat(maxLength + 1) + + // Both should accept exactly 400 characters + expect(validDescription.length).toBe(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + + // Both should reject 401 characters + expect(invalidDescription.length).toBe(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow() + }) + + test('validation error messages are consistent', () => { + const expectedErrorMessage = 'Description cannot exceed 400 characters.' + + // This would be the error message from both App and Dataset backend validation + expect(expectedErrorMessage).toBe('Description cannot exceed 400 characters.') + + const invalidDescription = 'x'.repeat(401) + try { + validateDescriptionLength(invalidDescription) + } + catch (error) { + expect((error as Error).message).toBe(expectedErrorMessage) + } + }) + }) + + describe('Character Length Edge Cases', () => { + const testCases = [ + { length: 0, shouldPass: true, description: 'empty description' }, + { length: 1, shouldPass: true, description: '1 character' }, + { length: 399, shouldPass: true, description: '399 characters' }, + { length: 400, shouldPass: true, description: '400 characters (boundary)' }, + { length: 401, shouldPass: false, description: '401 characters (over limit)' }, + { length: 500, shouldPass: false, description: '500 characters' }, + { length: 1000, shouldPass: false, description: '1000 characters' }, + ] + + testCases.forEach(({ length, shouldPass, description }) => { + test(`handles ${description} correctly`, () => { + const testDescription = length > 0 ? 
'x'.repeat(length) : '' + expect(testDescription.length).toBe(length) + + if (shouldPass) { + expect(() => validateDescriptionLength(testDescription)).not.toThrow() + expect(validateDescriptionLength(testDescription)).toBe(testDescription) + } + else { + expect(() => validateDescriptionLength(testDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + } + }) + }) + }) +}) diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx new file mode 100644 index 0000000000..200ed09ea9 --- /dev/null +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -0,0 +1,305 @@ +/** + * Document Detail Navigation Fix Verification Test + * + * This test specifically validates that the backToPrev function in the document detail + * component correctly preserves pagination and filter states. + */ + +import { fireEvent, render, screen } from '@testing-library/react' +import { useRouter } from 'next/navigation' +import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' + +// Mock Next.js router +const mockPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: jest.fn(() => ({ + push: mockPush, + })), +})) + +// Mock the document service hooks +jest.mock('@/service/knowledge/use-document', () => ({ + useDocumentDetail: jest.fn(), + useDocumentMetadata: jest.fn(), + useInvalidDocumentList: jest.fn(() => jest.fn()), +})) + +// Mock other dependencies +jest.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContext: jest.fn(() => [null]), +})) + +jest.mock('@/service/use-base', () => ({ + useInvalid: jest.fn(() => jest.fn()), +})) + +jest.mock('@/service/knowledge/use-segment', () => ({ + useSegmentListKey: jest.fn(), + useChildSegmentListKey: jest.fn(), +})) + +// Create a minimal version of the DocumentDetail component that includes our fix +const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; documentId: string }) => { + const router = useRouter() + + // This is the FIXED implementation from detail/index.tsx + const backToPrev = () => { + // Preserve pagination and filter states when navigating back + const searchParams = new URLSearchParams(window.location.search) + const queryString = searchParams.toString() + const separator = queryString ? '?' : '' + const backPath = `/datasets/${datasetId}/documents${separator}${queryString}` + router.push(backPath) + } + + return ( +
+    <div>
+      <button type="button" data-testid="back-button-fixed" onClick={backToPrev}>
+        Back
+      </button>
+      <div>
+        Dataset: {datasetId}, Document: {documentId}
+      </div>
+    </div>
+ ) +} + +describe('Document Detail Navigation Fix Verification', () => { + beforeEach(() => { + jest.clearAllMocks() + + // Mock successful API responses + ;(useDocumentDetail as jest.Mock).mockReturnValue({ + data: { + id: 'doc-123', + name: 'Test Document', + display_status: 'available', + enabled: true, + archived: false, + }, + error: null, + }) + + ;(useDocumentMetadata as jest.Mock).mockReturnValue({ + data: null, + error: null, + }) + }) + + describe('Query Parameter Preservation', () => { + test('preserves pagination state (page 3, limit 25)', () => { + // Simulate user coming from page 3 with 25 items per page + Object.defineProperty(window, 'location', { + value: { + search: '?page=3&limit=25', + }, + writable: true, + }) + + render() + + // User clicks back button + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should preserve the pagination state + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=3&limit=25') + + console.log('✅ Pagination state preserved: page=3&limit=25') + }) + + test('preserves search keyword and filters', () => { + // Simulate user with search and filters applied + Object.defineProperty(window, 'location', { + value: { + search: '?page=2&limit=10&keyword=API%20documentation&status=active', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should preserve all query parameters + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=10&keyword=API+documentation&status=active') + + console.log('✅ Search and filters preserved') + }) + + test('handles complex query parameters with special characters', () => { + // Test with complex query string including encoded characters + Object.defineProperty(window, 'location', { + value: { + search: '?page=1&limit=50&keyword=test%20%26%20debug&sort=name&order=desc&filter=%7B%22type%22%3A%22pdf%22%7D', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // URLSearchParams will normalize the encoding, but preserve all parameters + const expectedCall = mockPush.mock.calls[0][0] + expect(expectedCall).toMatch(/^\/datasets\/dataset-123\/documents\?/) + expect(expectedCall).toMatch(/page=1/) + expect(expectedCall).toMatch(/limit=50/) + expect(expectedCall).toMatch(/keyword=test/) + expect(expectedCall).toMatch(/sort=name/) + expect(expectedCall).toMatch(/order=desc/) + + console.log('✅ Complex query parameters handled:', expectedCall) + }) + + test('handles empty query parameters gracefully', () => { + // No query parameters in URL + Object.defineProperty(window, 'location', { + value: { + search: '', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should navigate to clean documents URL + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents') + + console.log('✅ Empty parameters handled gracefully') + }) + }) + + describe('Different Dataset IDs', () => { + test('works with different dataset identifiers', () => { + Object.defineProperty(window, 'location', { + value: { + search: '?page=5&limit=10', + }, + writable: true, + }) + + // Test with different dataset ID format + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + expect(mockPush).toHaveBeenCalledWith('/datasets/ds-prod-2024-001/documents?page=5&limit=10') + + console.log('✅ Works with different dataset ID formats') + }) + }) + + describe('Real User Scenarios', () => { + test('scenario: 
user searches, goes to page 3, views document, clicks back', () => { + // User searched for "API" and navigated to page 3 + Object.defineProperty(window, 'location', { + value: { + search: '?keyword=API&page=3&limit=10', + }, + writable: true, + }) + + render() + + // User decides to go back to continue browsing + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should return to page 3 of API search results + expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?keyword=API&page=3&limit=10') + + console.log('✅ Real user scenario: search + pagination preserved') + }) + + test('scenario: user applies multiple filters, goes to document, returns', () => { + // User has applied multiple filters and is on page 2 + Object.defineProperty(window, 'location', { + value: { + search: '?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // All filters should be preserved + expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-dataset/documents?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc') + + console.log('✅ Complex filtering scenario preserved') + }) + }) + + describe('Error Handling and Edge Cases', () => { + test('handles malformed query parameters gracefully', () => { + // Test with potentially problematic query string + Object.defineProperty(window, 'location', { + value: { + search: '?page=invalid&limit=&keyword=test&=emptykey&malformed', + }, + writable: true, + }) + + render() + + // Should not throw errors + expect(() => { + fireEvent.click(screen.getByTestId('back-button-fixed')) + }).not.toThrow() + + // Should still attempt navigation (URLSearchParams will clean up the parameters) + expect(mockPush).toHaveBeenCalled() + const navigationPath = mockPush.mock.calls[0][0] + expect(navigationPath).toMatch(/^\/datasets\/dataset-123\/documents/) + + console.log('✅ Malformed parameters handled gracefully:', navigationPath) + }) + + test('handles very long query strings', () => { + // Test with a very long query string + const longKeyword = 'a'.repeat(1000) + Object.defineProperty(window, 'location', { + value: { + search: `?page=1&keyword=${longKeyword}`, + }, + writable: true, + }) + + render() + + expect(() => { + fireEvent.click(screen.getByTestId('back-button-fixed')) + }).not.toThrow() + + expect(mockPush).toHaveBeenCalled() + + console.log('✅ Long query strings handled') + }) + }) + + describe('Performance Verification', () => { + test('navigation function executes quickly', () => { + Object.defineProperty(window, 'location', { + value: { + search: '?page=1&limit=10&keyword=test', + }, + writable: true, + }) + + render() + + const startTime = performance.now() + fireEvent.click(screen.getByTestId('back-button-fixed')) + const endTime = performance.now() + + const executionTime = endTime - startTime + + // Should execute in less than 10ms + expect(executionTime).toBeLessThan(10) + + console.log(`⚡ Navigation execution time: ${executionTime.toFixed(2)}ms`) + }) + }) +}) diff --git a/web/__tests__/document-list-sorting.test.tsx b/web/__tests__/document-list-sorting.test.tsx new file mode 100644 index 0000000000..1510dbec23 --- /dev/null +++ b/web/__tests__/document-list-sorting.test.tsx @@ -0,0 +1,83 @@ +/** + * Document List Sorting Tests + */ + +describe('Document List Sorting', () => { + const mockDocuments = [ + { id: '1', name: 'Beta.pdf', word_count: 500, hit_count: 10, created_at: 1699123456 }, + { id: '2', 
name: 'Alpha.txt', word_count: 200, hit_count: 25, created_at: 1699123400 }, + { id: '3', name: 'Gamma.docx', word_count: 800, hit_count: 5, created_at: 1699123500 }, + ] + + const sortDocuments = (docs: any[], field: string, order: 'asc' | 'desc') => { + return [...docs].sort((a, b) => { + let aValue: any + let bValue: any + + switch (field) { + case 'name': + aValue = a.name?.toLowerCase() || '' + bValue = b.name?.toLowerCase() || '' + break + case 'word_count': + aValue = a.word_count || 0 + bValue = b.word_count || 0 + break + case 'hit_count': + aValue = a.hit_count || 0 + bValue = b.hit_count || 0 + break + case 'created_at': + aValue = a.created_at + bValue = b.created_at + break + default: + return 0 + } + + if (field === 'name') { + const result = aValue.localeCompare(bValue) + return order === 'asc' ? result : -result + } + else { + const result = aValue - bValue + return order === 'asc' ? result : -result + } + }) + } + + test('sorts by name descending (default for UI consistency)', () => { + const sorted = sortDocuments(mockDocuments, 'name', 'desc') + expect(sorted.map(doc => doc.name)).toEqual(['Gamma.docx', 'Beta.pdf', 'Alpha.txt']) + }) + + test('sorts by name ascending (after toggle)', () => { + const sorted = sortDocuments(mockDocuments, 'name', 'asc') + expect(sorted.map(doc => doc.name)).toEqual(['Alpha.txt', 'Beta.pdf', 'Gamma.docx']) + }) + + test('sorts by word_count descending', () => { + const sorted = sortDocuments(mockDocuments, 'word_count', 'desc') + expect(sorted.map(doc => doc.word_count)).toEqual([800, 500, 200]) + }) + + test('sorts by hit_count descending', () => { + const sorted = sortDocuments(mockDocuments, 'hit_count', 'desc') + expect(sorted.map(doc => doc.hit_count)).toEqual([25, 10, 5]) + }) + + test('sorts by created_at descending (newest first)', () => { + const sorted = sortDocuments(mockDocuments, 'created_at', 'desc') + expect(sorted.map(doc => doc.created_at)).toEqual([1699123500, 1699123456, 1699123400]) + }) + + test('handles empty values correctly', () => { + const docsWithEmpty = [ + { id: '1', name: 'Test', word_count: 100, hit_count: 5, created_at: 1699123456 }, + { id: '2', name: 'Empty', word_count: 0, hit_count: 0, created_at: 1699123400 }, + ] + + const sorted = sortDocuments(docsWithEmpty, 'word_count', 'desc') + expect(sorted.map(doc => doc.word_count)).toEqual([100, 0]) + }) +}) diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts new file mode 100644 index 0000000000..9a388505d6 --- /dev/null +++ b/web/__tests__/navigation-utils.test.ts @@ -0,0 +1,290 @@ +/** + * Navigation Utilities Test + * + * Tests for the navigation utility functions to ensure they handle + * query parameter preservation correctly across different scenarios. 
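 *
 * A minimal sketch of the two helpers these specs exercise most, inferred from the
 * assertions below (an assumption — the actual implementation in @/utils/navigation
 * may be structured differently):
 *
 *   export const createNavigationPath = (basePath: string, preserveParams = true): string => {
 *     if (!preserveParams)
 *       return basePath
 *     try {
 *       // Re-serialize the current query string and append it to the target path.
 *       const query = new URLSearchParams(window.location.search).toString()
 *       return query ? `${basePath}?${query}` : basePath
 *     }
 *     catch (error) {
 *       console.warn('Failed to preserve query parameters:', error)
 *       return basePath
 *     }
 *   }
 *
 *   export const createBackNavigation = (
 *     router: { push: (path: string) => void },
 *     basePath: string,
 *     preserveParams = true,
 *   ) => () => router.push(createNavigationPath(basePath, preserveParams))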
+ */ + +import { + createBackNavigation, + createNavigationPath, + createNavigationPathWithParams, + datasetNavigation, + extractQueryParams, + mergeQueryParams, +} from '@/utils/navigation' + +// Mock router for testing +const mockPush = jest.fn() +const mockRouter = { push: mockPush } + +describe('Navigation Utilities', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('createNavigationPath', () => { + test('preserves query parameters by default', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10&keyword=test' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + expect(path).toBe('/datasets/123/documents?page=3&limit=10&keyword=test') + }) + + test('returns clean path when preserveParams is false', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents', false) + expect(path).toBe('/datasets/123/documents') + }) + + test('handles empty query parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + expect(path).toBe('/datasets/123/documents') + }) + + test('handles errors gracefully', () => { + // Mock window.location to throw an error + Object.defineProperty(window, 'location', { + get: () => { + throw new Error('Location access denied') + }, + configurable: true, + }) + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const path = createNavigationPath('/datasets/123/documents') + + expect(path).toBe('/datasets/123/documents') + expect(consoleSpy).toHaveBeenCalledWith('Failed to preserve query parameters:', expect.any(Error)) + + consoleSpy.mockRestore() + }) + }) + + describe('createBackNavigation', () => { + test('creates function that navigates with preserved params', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=2&limit=25' }, + writable: true, + }) + + const backFn = createBackNavigation(mockRouter, '/datasets/123/documents') + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents?page=2&limit=25') + }) + + test('creates function that navigates without params when specified', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=2&limit=25' }, + writable: true, + }) + + const backFn = createBackNavigation(mockRouter, '/datasets/123/documents', false) + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents') + }) + }) + + describe('extractQueryParams', () => { + test('extracts specified parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10&keyword=test&other=value' }, + writable: true, + }) + + const params = extractQueryParams(['page', 'limit', 'keyword']) + expect(params).toEqual({ + page: '3', + limit: '10', + keyword: 'test', + }) + }) + + test('handles missing parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3' }, + writable: true, + }) + + const params = extractQueryParams(['page', 'limit', 'missing']) + expect(params).toEqual({ + page: '3', + }) + }) + + test('handles errors gracefully', () => { + Object.defineProperty(window, 'location', { + get: () => { + throw new Error('Location access denied') + }, + configurable: true, + }) + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const params 
= extractQueryParams(['page', 'limit']) + + expect(params).toEqual({}) + expect(consoleSpy).toHaveBeenCalledWith('Failed to extract query parameters:', expect.any(Error)) + + consoleSpy.mockRestore() + }) + }) + + describe('createNavigationPathWithParams', () => { + test('creates path with specified parameters', () => { + const path = createNavigationPathWithParams('/datasets/123/documents', { + page: 1, + limit: 25, + keyword: 'search term', + }) + + expect(path).toBe('/datasets/123/documents?page=1&limit=25&keyword=search+term') + }) + + test('filters out empty values', () => { + const path = createNavigationPathWithParams('/datasets/123/documents', { + page: 1, + limit: '', + keyword: 'test', + empty: null, + undefined, + }) + + expect(path).toBe('/datasets/123/documents?page=1&keyword=test') + }) + + test('handles errors gracefully', () => { + // Mock URLSearchParams to throw an error + const originalURLSearchParams = globalThis.URLSearchParams + globalThis.URLSearchParams = jest.fn(() => { + throw new Error('URLSearchParams error') + }) as any + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 }) + + expect(path).toBe('/datasets/123/documents') + expect(consoleSpy).toHaveBeenCalledWith('Failed to create navigation path with params:', expect.any(Error)) + + consoleSpy.mockRestore() + globalThis.URLSearchParams = originalURLSearchParams + }) + }) + + describe('mergeQueryParams', () => { + test('merges new params with existing ones', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10' }, + writable: true, + }) + + const merged = mergeQueryParams({ keyword: 'test', page: '1' }) + const result = merged.toString() + + expect(result).toContain('page=1') // overridden + expect(result).toContain('limit=10') // preserved + expect(result).toContain('keyword=test') // added + }) + + test('removes parameters when value is null', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10&keyword=test' }, + writable: true, + }) + + const merged = mergeQueryParams({ keyword: null, filter: 'active' }) + const result = merged.toString() + + expect(result).toContain('page=3') + expect(result).toContain('limit=10') + expect(result).not.toContain('keyword') + expect(result).toContain('filter=active') + }) + + test('creates fresh params when preserveExisting is false', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10' }, + writable: true, + }) + + const merged = mergeQueryParams({ keyword: 'test' }, false) + const result = merged.toString() + + expect(result).toBe('keyword=test') + }) + }) + + describe('datasetNavigation', () => { + test('backToDocuments creates correct navigation function', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=2&limit=25' }, + writable: true, + }) + + const backFn = datasetNavigation.backToDocuments(mockRouter, 'dataset-123') + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=25') + }) + + test('toDocumentDetail creates correct navigation function', () => { + const detailFn = datasetNavigation.toDocumentDetail(mockRouter, 'dataset-123', 'doc-456') + detailFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456') + }) + + test('toDocumentSettings creates correct navigation function', () => { + const settingsFn = 
datasetNavigation.toDocumentSettings(mockRouter, 'dataset-123', 'doc-456') + settingsFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456/settings') + }) + }) + + describe('Real-world Integration Scenarios', () => { + test('complete user workflow: list -> detail -> back', () => { + // User starts on page 3 with search + Object.defineProperty(window, 'location', { + value: { search: '?page=3&keyword=API&limit=25' }, + writable: true, + }) + + // Create back navigation function (as would be done in detail component) + const backToDocuments = datasetNavigation.backToDocuments(mockRouter, 'main-dataset') + + // User clicks back + backToDocuments() + + // Should return to exact same list state + expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?page=3&keyword=API&limit=25') + }) + + test('user applies filters then views document', () => { + // Complex filter state + Object.defineProperty(window, 'location', { + value: { search: '?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc' }, + writable: true, + }) + + const backFn = createBackNavigation(mockRouter, '/datasets/filtered-set/documents') + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc') + }) + }) +}) diff --git a/web/__tests__/plugin-tool-workflow-error.test.tsx b/web/__tests__/plugin-tool-workflow-error.test.tsx new file mode 100644 index 0000000000..370052bc80 --- /dev/null +++ b/web/__tests__/plugin-tool-workflow-error.test.tsx @@ -0,0 +1,207 @@ +/** + * Test cases to reproduce the plugin tool workflow error + * Issue: #23154 - Application error when loading plugin tools in workflow + * Root cause: split() operation called on null/undefined values + */ + +describe('Plugin Tool Workflow Error Reproduction', () => { + /** + * Mock function to simulate the problematic code in switch-plugin-version.tsx:29 + * const [pluginId] = uniqueIdentifier.split(':') + */ + const mockSwitchPluginVersionLogic = (uniqueIdentifier: string | null | undefined) => { + // This directly reproduces the problematic line from switch-plugin-version.tsx:29 + const [pluginId] = uniqueIdentifier!.split(':') + return pluginId + } + + /** + * Test case 1: Simulate null uniqueIdentifier + * This should reproduce the error mentioned in the issue + */ + it('should reproduce error when uniqueIdentifier is null', () => { + expect(() => { + mockSwitchPluginVersionLogic(null) + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test case 2: Simulate undefined uniqueIdentifier + */ + it('should reproduce error when uniqueIdentifier is undefined', () => { + expect(() => { + mockSwitchPluginVersionLogic(undefined) + }).toThrow('Cannot read properties of undefined (reading \'split\')') + }) + + /** + * Test case 3: Simulate empty string uniqueIdentifier + */ + it('should handle empty string uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('') + expect(result).toBe('') // Empty string split by ':' returns [''] + }).not.toThrow() + }) + + /** + * Test case 4: Simulate malformed uniqueIdentifier without colon separator + */ + it('should handle malformed uniqueIdentifier without colon separator', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('malformed-identifier-without-colon') + expect(result).toBe('malformed-identifier-without-colon') // No colon means full string returned + }).not.toThrow() + }) + + /** + * 
Test case 5: Simulate valid uniqueIdentifier + */ + it('should work correctly with valid uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('valid-plugin-id:1.0.0') + expect(result).toBe('valid-plugin-id') + }).not.toThrow() + }) +}) + +/** + * Test for the variable processing split error in use-single-run-form-params + */ +describe('Variable Processing Split Error', () => { + /** + * Mock function to simulate the problematic code in use-single-run-form-params.ts:91 + * const getDependentVars = () => { + * return varInputs.map(item => item.variable.slice(1, -1).split('.')) + * } + */ + const mockGetDependentVars = (varInputs: Array<{ variable: string | null | undefined }>) => { + return varInputs.map((item) => { + // Guard against null/undefined variable to prevent app crash + if (!item.variable || typeof item.variable !== 'string') + return [] + + return item.variable.slice(1, -1).split('.') + }).filter(arr => arr.length > 0) // Filter out empty arrays + } + + /** + * Test case 1: Variable processing with null variable + */ + it('should handle null variable safely', () => { + const varInputs = [{ variable: null }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // null variables are filtered out + }) + + /** + * Test case 2: Variable processing with undefined variable + */ + it('should handle undefined variable safely', () => { + const varInputs = [{ variable: undefined }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // undefined variables are filtered out + }) + + /** + * Test case 3: Variable processing with empty string + */ + it('should handle empty string variable', () => { + const varInputs = [{ variable: '' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // Empty string is filtered out, so result is empty array + }) + + /** + * Test case 4: Variable processing with valid variable format + */ + it('should work correctly with valid variable format', () => { + const varInputs = [{ variable: '{{workflow.node.output}}' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result[0]).toEqual(['{workflow', 'node', 'output}']) + }) +}) + +/** + * Integration test to simulate the complete workflow scenario + */ +describe('Plugin Tool Workflow Integration', () => { + /** + * Simulate the scenario where plugin metadata is incomplete or corrupted + * This can happen when: + * 1. Plugin is being loaded from marketplace but metadata request fails + * 2. Plugin configuration is corrupted in database + * 3. 
Network issues during plugin loading + */ + it('should reproduce the client-side exception scenario', () => { + // Mock incomplete plugin data that could cause the error + const incompletePluginData = { + // Missing or null uniqueIdentifier + uniqueIdentifier: null, + meta: null, + minimum_dify_version: undefined, + } + + // This simulates the error path that leads to the white screen + expect(() => { + // Simulate the code path in switch-plugin-version.tsx:29 + // The actual problematic code doesn't use optional chaining + const _pluginId = (incompletePluginData.uniqueIdentifier as any).split(':')[0] + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test the scenario mentioned in the issue where plugin tools are loaded in workflow + */ + it('should simulate plugin tool loading in workflow context', () => { + // Mock the workflow context where plugin tools are being loaded + const workflowPluginTools = [ + { + provider_name: 'test-plugin', + uniqueIdentifier: null, // This is the problematic case + tool_name: 'test-tool', + }, + { + provider_name: 'valid-plugin', + uniqueIdentifier: 'valid-plugin:1.0.0', + tool_name: 'valid-tool', + }, + ] + + // Process each plugin tool + workflowPluginTools.forEach((tool, _index) => { + if (tool.uniqueIdentifier === null) { + // This reproduces the exact error scenario + expect(() => { + const _pluginId = (tool.uniqueIdentifier as any).split(':')[0] + }).toThrow() + } + else { + // Valid tools should work fine + expect(() => { + const _pluginId = tool.uniqueIdentifier.split(':')[0] + }).not.toThrow() + } + }) + }) +}) diff --git a/web/__tests__/unified-tags-logic.test.ts b/web/__tests__/unified-tags-logic.test.ts new file mode 100644 index 0000000000..c920e28e0a --- /dev/null +++ b/web/__tests__/unified-tags-logic.test.ts @@ -0,0 +1,396 @@ +/** + * Unified Tags Editing - Pure Logic Tests + * + * This test file validates the core business logic and state management + * behaviors introduced in the recent 7 commits without requiring complex mocks. 
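 *
 * A hedged sketch of how the add/remove diff computed in these tests could be applied,
 * with the bind/unbind signatures taken from the "maintain consistent API parameters"
 * test near the end of this file; the function and type names here are illustrative
 * assumptions, not the real TagSelector code:
 *
 *   type TagOps = {
 *     bindTag: (tagIDs: string[], targetID: string, type: string) => Promise<void>
 *     unBindTag: (tagID: string, targetID: string, type: string) => Promise<void>
 *   }
 *
 *   export const applyTagSelection = async (
 *     currentValue: string[],
 *     selectedTagIDs: string[],
 *     targetID: string,
 *     ops: TagOps,
 *   ): Promise<void> => {
 *     // Same diffing logic the tests below describe.
 *     const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
 *     const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
 *     await Promise.all([
 *       ...(addTagIDs.length ? [ops.bindTag(addTagIDs, targetID, 'app')] : []),
 *       ...removeTagIDs.map(tagID => ops.unBindTag(tagID, targetID, 'app')),
 *     ])
 *   }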
+ */ + +describe('Unified Tags Editing - Pure Logic Tests', () => { + describe('Tag State Management Logic', () => { + it('should detect when tag values have changed', () => { + const currentValue = ['tag1', 'tag2'] + const newSelectedTagIDs = ['tag1', 'tag3'] + + // This is the valueNotChanged logic from TagSelector component + const valueNotChanged + = currentValue.length === newSelectedTagIDs.length + && currentValue.every(v => newSelectedTagIDs.includes(v)) + && newSelectedTagIDs.every(v => currentValue.includes(v)) + + expect(valueNotChanged).toBe(false) + }) + + it('should correctly identify unchanged tag values', () => { + const currentValue = ['tag1', 'tag2'] + const newSelectedTagIDs = ['tag2', 'tag1'] // Same tags, different order + + const valueNotChanged + = currentValue.length === newSelectedTagIDs.length + && currentValue.every(v => newSelectedTagIDs.includes(v)) + && newSelectedTagIDs.every(v => currentValue.includes(v)) + + expect(valueNotChanged).toBe(true) + }) + + it('should calculate correct tag operations for binding/unbinding', () => { + const currentValue = ['tag1', 'tag2'] + const selectedTagIDs = ['tag2', 'tag3'] + + // This is the handleValueChange logic from TagSelector + const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v)) + const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v)) + + expect(addTagIDs).toEqual(['tag3']) + expect(removeTagIDs).toEqual(['tag1']) + }) + + it('should handle empty tag arrays correctly', () => { + const currentValue: string[] = [] + const selectedTagIDs = ['tag1'] + + const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v)) + const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v)) + + expect(addTagIDs).toEqual(['tag1']) + expect(removeTagIDs).toEqual([]) + expect(currentValue.length).toBe(0) // Verify empty array usage + }) + + it('should handle removing all tags', () => { + const currentValue = ['tag1', 'tag2'] + const selectedTagIDs: string[] = [] + + const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v)) + const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v)) + + expect(addTagIDs).toEqual([]) + expect(removeTagIDs).toEqual(['tag1', 'tag2']) + expect(selectedTagIDs.length).toBe(0) // Verify empty array usage + }) + }) + + describe('Fallback Logic (from layout-main.tsx)', () => { + it('should trigger fallback when tags are missing or empty', () => { + const appDetailWithoutTags = { tags: [] } + const appDetailWithTags = { tags: [{ id: 'tag1' }] } + const appDetailWithUndefinedTags = { tags: undefined as any } + + // This simulates the condition in layout-main.tsx + const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0 + const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0 + const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0 + + expect(shouldFallback1).toBe(true) // Empty array should trigger fallback + expect(shouldFallback2).toBe(false) // Has tags, no fallback needed + expect(shouldFallback3).toBe(true) // Undefined tags should trigger fallback + }) + + it('should preserve tags when fallback succeeds', () => { + const originalAppDetail = { tags: [] as any[] } + const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } + + // This simulates the successful fallback in layout-main.tsx + if (fallbackResult?.tags) + originalAppDetail.tags = fallbackResult.tags + + 
expect(originalAppDetail.tags).toEqual(fallbackResult.tags) + expect(originalAppDetail.tags.length).toBe(1) + }) + + it('should continue with empty tags when fallback fails', () => { + const originalAppDetail: { tags: any[] } = { tags: [] } + const fallbackResult: { tags?: any[] } | null = null + + // This simulates fallback failure in layout-main.tsx + if (fallbackResult?.tags) + originalAppDetail.tags = fallbackResult.tags + + expect(originalAppDetail.tags).toEqual([]) + }) + }) + + describe('TagSelector Auto-initialization Logic', () => { + it('should trigger getTagList when tagList is empty', () => { + const tagList: any[] = [] + let getTagListCalled = false + const getTagList = () => { + getTagListCalled = true + } + + // This simulates the useEffect in TagSelector + if (tagList.length === 0) + getTagList() + + expect(getTagListCalled).toBe(true) + }) + + it('should not trigger getTagList when tagList has items', () => { + const tagList = [{ id: 'tag1', name: 'existing-tag' }] + let getTagListCalled = false + const getTagList = () => { + getTagListCalled = true + } + + // This simulates the useEffect in TagSelector + if (tagList.length === 0) + getTagList() + + expect(getTagListCalled).toBe(false) + }) + }) + + describe('State Initialization Patterns', () => { + it('should maintain AppCard tag state pattern', () => { + const app = { tags: [{ id: 'tag1', name: 'test' }] } + + // Original AppCard pattern: useState(app.tags) + const initialTags = app.tags + expect(Array.isArray(initialTags)).toBe(true) + expect(initialTags.length).toBe(1) + expect(initialTags).toBe(app.tags) // Reference equality for AppCard + }) + + it('should maintain AppInfo tag state pattern', () => { + const appDetail = { tags: [{ id: 'tag1', name: 'test' }] } + + // New AppInfo pattern: useState(appDetail?.tags || []) + const initialTags = appDetail?.tags || [] + expect(Array.isArray(initialTags)).toBe(true) + expect(initialTags.length).toBe(1) + }) + + it('should handle undefined appDetail gracefully in AppInfo', () => { + const appDetail = undefined + + // AppInfo pattern with undefined appDetail + const initialTags = (appDetail as any)?.tags || [] + expect(Array.isArray(initialTags)).toBe(true) + expect(initialTags.length).toBe(0) + }) + }) + + describe('CSS Class and Layout Logic', () => { + it('should apply correct minimum width condition', () => { + const minWidth = 'true' + + // This tests the minWidth logic in TagSelector + const shouldApplyMinWidth = minWidth && '!min-w-80' + expect(shouldApplyMinWidth).toBe('!min-w-80') + }) + + it('should not apply minimum width when not specified', () => { + const minWidth = undefined + + const shouldApplyMinWidth = minWidth && '!min-w-80' + expect(shouldApplyMinWidth).toBeFalsy() + }) + + it('should handle overflow layout classes correctly', () => { + // This tests the layout pattern from AppCard and new AppInfo + const overflowLayoutClasses = { + container: 'flex w-0 grow items-center', + inner: 'w-full', + truncate: 'truncate', + } + + expect(overflowLayoutClasses.container).toContain('w-0 grow') + expect(overflowLayoutClasses.inner).toContain('w-full') + expect(overflowLayoutClasses.truncate).toBe('truncate') + }) + }) + + describe('fetchAppWithTags Service Logic', () => { + it('should correctly find app by ID from app list', () => { + const appList = [ + { id: 'app1', name: 'App 1', tags: [] }, + { id: 'test-app-id', name: 'Test App', tags: [{ id: 'tag1', name: 'test' }] }, + { id: 'app3', name: 'App 3', tags: [] }, + ] + const targetAppId = 'test-app-id' + + // 
This simulates the logic in fetchAppWithTags + const foundApp = appList.find(app => app.id === targetAppId) + + expect(foundApp).toBeDefined() + expect(foundApp?.id).toBe('test-app-id') + expect(foundApp?.tags.length).toBe(1) + }) + + it('should return null when app not found', () => { + const appList = [ + { id: 'app1', name: 'App 1' }, + { id: 'app2', name: 'App 2' }, + ] + const targetAppId = 'nonexistent-app' + + const foundApp = appList.find(app => app.id === targetAppId) || null + + expect(foundApp).toBeNull() + }) + + it('should handle empty app list', () => { + const appList: any[] = [] + const targetAppId = 'any-app' + + const foundApp = appList.find(app => app.id === targetAppId) || null + + expect(foundApp).toBeNull() + expect(appList.length).toBe(0) // Verify empty array usage + }) + }) + + describe('Data Structure Validation', () => { + it('should maintain consistent tag data structure', () => { + const tag = { + id: 'tag1', + name: 'test-tag', + type: 'app', + binding_count: 1, + } + + expect(tag).toHaveProperty('id') + expect(tag).toHaveProperty('name') + expect(tag).toHaveProperty('type') + expect(tag).toHaveProperty('binding_count') + expect(tag.type).toBe('app') + expect(typeof tag.binding_count).toBe('number') + }) + + it('should handle tag arrays correctly', () => { + const tags = [ + { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 }, + { id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 }, + ] + + expect(Array.isArray(tags)).toBe(true) + expect(tags.length).toBe(2) + expect(tags.every(tag => tag.type === 'app')).toBe(true) + }) + + it('should validate app data structure with tags', () => { + const app = { + id: 'test-app', + name: 'Test App', + tags: [ + { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 }, + ], + } + + expect(app).toHaveProperty('id') + expect(app).toHaveProperty('name') + expect(app).toHaveProperty('tags') + expect(Array.isArray(app.tags)).toBe(true) + expect(app.tags.length).toBe(1) + }) + }) + + describe('Performance and Edge Cases', () => { + it('should handle large tag arrays efficiently', () => { + const largeTags = Array.from({ length: 100 }, (_, i) => `tag${i}`) + const selectedTags = ['tag1', 'tag50', 'tag99'] + + // Performance test: filtering should be efficient + const startTime = Date.now() + const addTags = selectedTags.filter(tag => !largeTags.includes(tag)) + const removeTags = largeTags.filter(tag => !selectedTags.includes(tag)) + const endTime = Date.now() + + expect(endTime - startTime).toBeLessThan(10) // Should be very fast + expect(addTags.length).toBe(0) // All selected tags exist + expect(removeTags.length).toBe(97) // 100 - 3 = 97 tags to remove + }) + + it('should handle malformed tag data gracefully', () => { + const mixedData = [ + { id: 'valid1', name: 'Valid Tag', type: 'app', binding_count: 1 }, + { id: 'invalid1' }, // Missing required properties + null, + undefined, + { id: 'valid2', name: 'Another Valid', type: 'app', binding_count: 0 }, + ] + + // Filter out invalid entries + const validTags = mixedData.filter((tag): tag is { id: string; name: string; type: string; binding_count: number } => + tag != null + && typeof tag === 'object' + && 'id' in tag + && 'name' in tag + && 'type' in tag + && 'binding_count' in tag + && typeof tag.binding_count === 'number', + ) + + expect(validTags.length).toBe(2) + expect(validTags.every(tag => tag.id && tag.name)).toBe(true) + }) + + it('should handle concurrent tag operations correctly', () => { + const operations = [ + { type: 'add', tagIds: ['tag1', 
'tag2'] }, + { type: 'remove', tagIds: ['tag3'] }, + { type: 'add', tagIds: ['tag4'] }, + ] + + // Simulate processing operations + const results = operations.map(op => ({ + ...op, + processed: true, + timestamp: Date.now(), + })) + + expect(results.length).toBe(3) + expect(results.every(result => result.processed)).toBe(true) + }) + }) + + describe('Backward Compatibility Verification', () => { + it('should not break existing AppCard behavior', () => { + // Verify AppCard continues to work with original patterns + const originalAppCardLogic = { + initializeTags: (app: any) => app.tags, + updateTags: (_currentTags: any[], newTags: any[]) => newTags, + shouldRefresh: true, + } + + const app = { tags: [{ id: 'tag1', name: 'original' }] } + const initializedTags = originalAppCardLogic.initializeTags(app) + + expect(initializedTags).toBe(app.tags) + expect(originalAppCardLogic.shouldRefresh).toBe(true) + }) + + it('should ensure AppInfo follows AppCard patterns', () => { + // Verify AppInfo uses compatible state management + const appCardPattern = (app: any) => app.tags + const appInfoPattern = (appDetail: any) => appDetail?.tags || [] + + const appWithTags = { tags: [{ id: 'tag1' }] } + const appWithoutTags = { tags: [] } + const undefinedApp = undefined + + expect(appCardPattern(appWithTags)).toEqual(appInfoPattern(appWithTags)) + expect(appInfoPattern(appWithoutTags)).toEqual([]) + expect(appInfoPattern(undefinedApp)).toEqual([]) + }) + + it('should maintain consistent API parameters', () => { + // Verify service layer maintains expected parameters + const fetchAppListParams = { + url: '/apps', + params: { page: 1, limit: 100 }, + } + + const tagApiParams = { + bindTag: (tagIDs: string[], targetID: string, type: string) => ({ tagIDs, targetID, type }), + unBindTag: (tagID: string, targetID: string, type: string) => ({ tagID, targetID, type }), + } + + expect(fetchAppListParams.url).toBe('/apps') + expect(fetchAppListParams.params.limit).toBe(100) + + const bindResult = tagApiParams.bindTag(['tag1'], 'app1', 'app') + expect(bindResult.tagIDs).toEqual(['tag1']) + expect(bindResult.type).toBe('app') + }) + }) +}) diff --git a/web/__tests__/xss-fix-verification.test.tsx b/web/__tests__/xss-fix-verification.test.tsx new file mode 100644 index 0000000000..2fa5ab3c05 --- /dev/null +++ b/web/__tests__/xss-fix-verification.test.tsx @@ -0,0 +1,212 @@ +/** + * XSS Fix Verification Test + * + * This test verifies that the XSS vulnerability in check-code pages has been + * properly fixed by replacing dangerouslySetInnerHTML with safe React rendering. + */ + +import React from 'react' +import { cleanup, render } from '@testing-library/react' +import '@testing-library/jest-dom' + +// Mock i18next with the new safe translation structure +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + if (key === 'login.checkCode.tipsPrefix') + return 'We send a verification code to ' + + return key + }, + }), +})) + +// Mock Next.js useSearchParams +jest.mock('next/navigation', () => ({ + useSearchParams: () => ({ + get: (key: string) => { + if (key === 'email') + return 'test@example.com' + return null + }, + }), +})) + +// Fixed CheckCode component implementation (current secure version) +const SecureCheckCodeComponent = ({ email }: { email: string }) => { + const { t } = require('react-i18next').useTranslation() + + return ( +
+      <div>
+        {/* Wrapper markup reconstructed from the assertions in this file; the original tags and class names were lost in extraction. */}
+        <h2>Check Code</h2>
+        <p>
+          <span>
+            {t('login.checkCode.tipsPrefix')}
+            <strong>{email}</strong>
+          </span>
+        </p>
+      </div>
+ ) +} + +// Vulnerable implementation for comparison (what we fixed) +const VulnerableCheckCodeComponent = ({ email }: { email: string }) => { + const mockTranslation = (key: string, params?: any) => { + if (key === 'login.checkCode.tips' && params?.email) + return `We send a verification code to ${params.email}` + + return key + } + + return ( +
+      <div>
+        {/* Reconstructed illustration of the unsafe pattern: the interpolated translation is injected as raw HTML. */}
+        <h2>Check Code</h2>
+        <p>
+          <span
+            dangerouslySetInnerHTML={{ __html: mockTranslation('login.checkCode.tips', { email }) }}
+          />
+        </p>
+      </div>
+ ) +} + +describe('XSS Fix Verification - Check Code Pages Security', () => { + afterEach(() => { + cleanup() + }) + + const maliciousEmail = 'test@example.com' + + it('should securely render email with HTML characters as text (FIXED VERSION)', () => { + console.log('\n🔒 Security Fix Verification Report') + console.log('===================================') + + const { container } = render() + + const spanElement = container.querySelector('span') + const strongElement = container.querySelector('strong') + const scriptElements = container.querySelectorAll('script') + + console.log('\n✅ Fixed Implementation Results:') + console.log('- Email rendered in strong tag:', strongElement?.textContent) + console.log('- HTML tags visible as text:', strongElement?.textContent?.includes('', + 'normal@email.com', + ] + + testCases.forEach((testEmail, index) => { + const { container } = render() + + const strongElement = container.querySelector('strong') + const scriptElements = container.querySelectorAll('script') + const imgElements = container.querySelectorAll('img') + const divElements = container.querySelectorAll('div:not([data-testid])') + + console.log(`\n📧 Test Case ${index + 1}: ${testEmail.substring(0, 20)}...`) + console.log(` - Script elements: ${scriptElements.length}`) + console.log(` - Img elements: ${imgElements.length}`) + console.log(` - Malicious divs: ${divElements.length - 1}`) // -1 for container div + console.log(` - Text content: ${strongElement?.textContent === testEmail ? 'SAFE' : 'ISSUE'}`) + + // All should be safe + expect(scriptElements).toHaveLength(0) + expect(imgElements).toHaveLength(0) + expect(strongElement?.textContent).toBe(testEmail) + }) + + console.log('\n✅ All test cases passed - secure rendering confirmed') + }) + + it('should validate the translation structure is secure', () => { + console.log('\n🔍 Translation Security Analysis') + console.log('=================================') + + const { t } = require('react-i18next').useTranslation() + const prefix = t('login.checkCode.tipsPrefix') + + console.log('- Translation key used: login.checkCode.tipsPrefix') + console.log('- Translation value:', prefix) + console.log('- Contains HTML tags:', prefix.includes('<')) + console.log('- Pure text content:', !prefix.includes('<') && !prefix.includes('>')) + + // Verify translation is plain text + expect(prefix).toBe('We send a verification code to ') + expect(prefix).not.toContain('<') + expect(prefix).not.toContain('>') + expect(typeof prefix).toBe('string') + + console.log('\n✅ Translation structure is secure - no HTML content') + }) + + it('should confirm React automatic escaping works correctly', () => { + console.log('\n⚡ React Security Mechanism Test') + console.log('=================================') + + // Test React's automatic escaping with various inputs + const dangerousInputs = [ + '', + '', + '">', + '\'>alert(3)', + '
click
', + ] + + dangerousInputs.forEach((input, index) => { + const TestComponent = () => {input} + const { container } = render() + + const strongElement = container.querySelector('strong') + const scriptElements = container.querySelectorAll('script') + + console.log(`\n🧪 Input ${index + 1}: ${input.substring(0, 30)}...`) + console.log(` - Rendered as text: ${strongElement?.textContent === input}`) + console.log(` - No script execution: ${scriptElements.length === 0}`) + + expect(strongElement?.textContent).toBe(input) + expect(scriptElements).toHaveLength(0) + }) + + console.log('\n🛡️ React automatic escaping is working perfectly') + }) +}) + +export {} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 6b3807f1c6..6d337e3c47 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -20,12 +20,18 @@ import cn from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' -import { fetchAppDetail } from '@/service/apps' +import { fetchAppDetailDirect } from '@/service/apps' import { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import type { App } from '@/types/app' import useDocumentTitle from '@/hooks/use-document-title' +import { useStore as useTagStore } from '@/app/components/base/tag-management/store' +import dynamic from 'next/dynamic' + +const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), { + ssr: false, +}) export type IAppDetailLayoutProps = { children: React.ReactNode @@ -48,6 +54,7 @@ const AppDetailLayout: FC = (props) => { setAppDetail: state.setAppDetail, setAppSiderbarExpand: state.setAppSiderbarExpand, }))) + const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false) const [appDetailRes, setAppDetailRes] = useState(null) const [navigation, setNavigation] = useState = (props) => { useEffect(() => { setAppDetail() setIsLoadingAppDetail(true) - fetchAppDetail({ url: '/apps', id: appId }).then((res) => { + fetchAppDetailDirect({ url: '/apps', id: appId }).then((res: App) => { setAppDetailRes(res) }).catch((e: any) => { if (e.status === 404) @@ -163,6 +170,9 @@ const AppDetailLayout: FC = (props) => {
{children}
+ {showTagManagementModal && ( + + )} ) } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx new file mode 100644 index 0000000000..a3281be8eb --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -0,0 +1,156 @@ +import React from 'react' +import { render } from '@testing-library/react' +import '@testing-library/jest-dom' +import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' + +// Mock dependencies to isolate the SVG rendering issue +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('SVG Attribute Error Reproduction', () => { + // Capture console errors + const originalError = console.error + let errorMessages: string[] = [] + + beforeEach(() => { + errorMessages = [] + console.error = jest.fn((message) => { + errorMessages.push(message) + originalError(message) + }) + }) + + afterEach(() => { + console.error = originalError + }) + + it('should reproduce inkscape attribute errors when rendering OpikIconBig', () => { + console.log('\n=== TESTING OpikIconBig SVG ATTRIBUTE ERRORS ===') + + // Test multiple renders to check for inconsistency + for (let i = 0; i < 5; i++) { + console.log(`\nRender attempt ${i + 1}:`) + + const { unmount } = render() + + // Check for specific inkscape attribute errors + const inkscapeErrors = errorMessages.filter(msg => + typeof msg === 'string' && msg.includes('inkscape'), + ) + + if (inkscapeErrors.length > 0) { + console.log(`Found ${inkscapeErrors.length} inkscape errors:`) + inkscapeErrors.forEach((error, index) => { + console.log(` ${index + 1}. 
${error.substring(0, 100)}...`) + }) + } + else { + console.log('No inkscape errors found in this render') + } + + unmount() + + // Clear errors for next iteration + errorMessages = [] + } + }) + + it('should analyze the SVG structure causing the errors', () => { + console.log('\n=== ANALYZING SVG STRUCTURE ===') + + // Import the JSON data directly + const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json') + + console.log('Icon structure analysis:') + console.log('- Root element:', iconData.icon.name) + console.log('- Children count:', iconData.icon.children?.length || 0) + + // Find problematic elements + const findProblematicElements = (node: any, path = '') => { + const problematicElements: any[] = [] + + if (node.name && (node.name.includes(':') || node.name.startsWith('sodipodi'))) { + problematicElements.push({ + path, + name: node.name, + attributes: Object.keys(node.attributes || {}), + }) + } + + // Check attributes for inkscape/sodipodi properties + if (node.attributes) { + const problematicAttrs = Object.keys(node.attributes).filter(attr => + attr.startsWith('inkscape:') || attr.startsWith('sodipodi:'), + ) + + if (problematicAttrs.length > 0) { + problematicElements.push({ + path, + name: node.name, + problematicAttributes: problematicAttrs, + }) + } + } + + if (node.children) { + node.children.forEach((child: any, index: number) => { + problematicElements.push( + ...findProblematicElements(child, `${path}/${node.name}[${index}]`), + ) + }) + } + + return problematicElements + } + + const problematicElements = findProblematicElements(iconData.icon, 'root') + + console.log(`\n🚨 Found ${problematicElements.length} problematic elements:`) + problematicElements.forEach((element, index) => { + console.log(`\n${index + 1}. 
Element: ${element.name}`) + console.log(` Path: ${element.path}`) + if (element.problematicAttributes) + console.log(` Problematic attributes: ${element.problematicAttributes.join(', ')}`) + }) + }) + + it('should test the normalizeAttrs function behavior', () => { + console.log('\n=== TESTING normalizeAttrs FUNCTION ===') + + const { normalizeAttrs } = require('@/app/components/base/icons/utils') + + const testAttributes = { + 'inkscape:showpageshadow': '2', + 'inkscape:pageopacity': '0.0', + 'inkscape:pagecheckerboard': '0', + 'inkscape:deskcolor': '#d1d1d1', + 'sodipodi:docname': 'opik-icon-big.svg', + 'xmlns:inkscape': 'https://www.inkscape.org/namespaces/inkscape', + 'xmlns:sodipodi': 'https://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd', + 'xmlns:svg': 'https://www.w3.org/2000/svg', + 'data-name': 'Layer 1', + 'normal-attr': 'value', + 'class': 'test-class', + } + + console.log('Input attributes:', Object.keys(testAttributes)) + + const normalized = normalizeAttrs(testAttributes) + + console.log('Normalized attributes:', Object.keys(normalized)) + console.log('Normalized values:', normalized) + + // Check if problematic attributes are still present + const problematicKeys = Object.keys(normalized).filter(key => + key.toLowerCase().includes('inkscape') || key.toLowerCase().includes('sodipodi'), + ) + + if (problematicKeys.length > 0) + console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`) + else + console.log('✅ No problematic attributes found after normalization') + }) +}) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 3d05575127..1ab40e31bf 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -1,12 +1,9 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useEffect, useRef, useState } from 'react' -import { - RiEqualizer2Line, -} from '@remixicon/react' +import React, { useCallback, useRef, useState } from 'react' + import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -17,13 +14,13 @@ type Props = { readOnly: boolean className?: string hasConfigured: boolean - controlShowPopup?: number + children?: React.ReactNode } & PopupProps const ConfigBtn: FC = ({ className, hasConfigured, - controlShowPopup, + children, ...popupProps }) => { const [open, doSetOpen] = useState(false) @@ -37,13 +34,6 @@ const ConfigBtn: FC = ({ setOpen(!openRef.current) }, [setOpen]) - useEffect(() => { - if (controlShowPopup) - // setOpen(!openRef.current) - setOpen(true) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [controlShowPopup]) - if (popupProps.readOnly && !hasConfigured) return null @@ -52,14 +42,11 @@ const ConfigBtn: FC = ({ open={open} onOpenChange={setOpen} placement='bottom-end' - offset={{ - mainAxis: 12, - crossAxis: hasConfigured ? 8 : 49, - }} + offset={12} > -
- +
+ {children}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index d082523222..7564a0f3c8 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -1,8 +1,9 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useEffect, useState } from 'react' +import React, { useEffect, useState } from 'react' import { RiArrowDownDoubleLine, + RiEqualizer2Line, } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { usePathname } from 'next/navigation' @@ -180,10 +181,6 @@ const Panel: FC = () => { })() }, []) - const [controlShowPopup, setControlShowPopup] = useState(0) - const showPopup = useCallback(() => { - setControlShowPopup(Date.now()) - }, [setControlShowPopup]) if (!isLoaded) { return (
@@ -196,46 +193,66 @@ const Panel: FC = () => { return (
-
- {!inUseTracingProvider && ( - <> + {!inUseTracingProvider && ( + +
{t(`${I18N_PREFIX}.title`)}
-
e.stopPropagation()}> - +
+
- - )} - {hasConfiguredTracing && ( - <> +
+ + )} + {hasConfiguredTracing && ( + +
@@ -243,33 +260,14 @@ const Panel: FC = () => {
{InUseProviderIcon && } - -
e.stopPropagation()}> - +
+
- - )} -
-
+ +
+
+ )} +
) } export default React.memo(Panel) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 426778c835..d70179266a 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -56,33 +56,50 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { }, [isMobile, setShowTips]) return
- {hasRelatedApps && ( - <> - {!isMobile && ( - - } - > -
- {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} - -
-
- )} + {/* Related apps for desktop */} +
+ + } + > +
+ {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} + +
+
+
- {isMobile &&
- {relatedAppsTotal || '--'} - -
} - - )} - {!hasRelatedApps && !expand && ( + {/* Related apps for mobile */} +
+
+ {relatedAppsTotal || '--'} + +
+
+ + {/* No related apps tooltip */} +
{
} > -
+
{t('common.datasetMenus.noRelatedApp')}
- )} +
} diff --git a/web/app/(commonLayout)/datasets/Doc.tsx b/web/app/(commonLayout)/datasets/Doc.tsx deleted file mode 100644 index 042a90f4af..0000000000 --- a/web/app/(commonLayout)/datasets/Doc.tsx +++ /dev/null @@ -1,131 +0,0 @@ -'use client' - -import { useEffect, useMemo, useState } from 'react' -import { useContext } from 'use-context-selector' -import { useTranslation } from 'react-i18next' -import { RiListUnordered } from '@remixicon/react' -import TemplateEn from './template/template.en.mdx' -import TemplateZh from './template/template.zh.mdx' -import TemplateJa from './template/template.ja.mdx' -import I18n from '@/context/i18n' -import { LanguagesSupported } from '@/i18n-config/language' -import useTheme from '@/hooks/use-theme' -import { Theme } from '@/types/app' -import cn from '@/utils/classnames' - -type DocProps = { - apiBaseUrl: string -} - -const Doc = ({ apiBaseUrl }: DocProps) => { - const { locale } = useContext(I18n) - const { t } = useTranslation() - const [toc, setToc] = useState>([]) - const [isTocExpanded, setIsTocExpanded] = useState(false) - const { theme } = useTheme() - - // Set initial TOC expanded state based on screen width - useEffect(() => { - const mediaQuery = window.matchMedia('(min-width: 1280px)') - setIsTocExpanded(mediaQuery.matches) - }, []) - - // Extract TOC from article content - useEffect(() => { - const extractTOC = () => { - const article = document.querySelector('article') - if (article) { - const headings = article.querySelectorAll('h2') - const tocItems = Array.from(headings).map((heading) => { - const anchor = heading.querySelector('a') - if (anchor) { - return { - href: anchor.getAttribute('href') || '', - text: anchor.textContent || '', - } - } - return null - }).filter((item): item is { href: string; text: string } => item !== null) - setToc(tocItems) - } - } - - setTimeout(extractTOC, 0) - }, [locale]) - - // Handle TOC item click - const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { - e.preventDefault() - const targetId = item.href.replace('#', '') - const element = document.getElementById(targetId) - if (element) { - const scrollContainer = document.querySelector('.scroll-container') - if (scrollContainer) { - const headerOffset = -40 - const elementTop = element.offsetTop - headerOffset - scrollContainer.scrollTo({ - top: elementTop, - behavior: 'smooth', - }) - } - } - } - - const Template = useMemo(() => { - switch (locale) { - case LanguagesSupported[1]: - return - case LanguagesSupported[7]: - return - default: - return - } - }, [apiBaseUrl, locale]) - - return ( -
-
- {isTocExpanded - ? ( - - ) - : ( - - )} -
-
- {Template} -
-
- ) -} - -export default Doc diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/container.tsx similarity index 98% rename from web/app/(commonLayout)/datasets/Container.tsx rename to web/app/(commonLayout)/datasets/container.tsx index 112b6a752e..444119332b 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/container.tsx @@ -9,10 +9,10 @@ import { useQuery } from '@tanstack/react-query' // Components import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel' -import Datasets from './Datasets' -import DatasetFooter from './DatasetFooter' +import Datasets from './datasets' +import DatasetFooter from './dataset-footer' import ApiServer from '../../components/develop/ApiServer' -import Doc from './Doc' +import Doc from './doc' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' diff --git a/web/app/(commonLayout)/datasets/create/page.tsx b/web/app/(commonLayout)/datasets/create/page.tsx index 663a830665..50fd1f5a19 100644 --- a/web/app/(commonLayout)/datasets/create/page.tsx +++ b/web/app/(commonLayout)/datasets/create/page.tsx @@ -1,9 +1,7 @@ import React from 'react' import DatasetUpdateForm from '@/app/components/datasets/create' -type Props = {} - -const DatasetCreation = async (props: Props) => { +const DatasetCreation = async () => { return ( ) diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/dataset-card.tsx similarity index 93% rename from web/app/(commonLayout)/datasets/DatasetCard.tsx rename to web/app/(commonLayout)/datasets/dataset-card.tsx index 4b40be2c7f..2f0563d47e 100644 --- a/web/app/(commonLayout)/datasets/DatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/dataset-card.tsx @@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { RiMoreFill } from '@remixicon/react' +import { mutate } from 'swr' import cn from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' import { ToastContext } from '@/app/components/base/toast' @@ -57,6 +58,19 @@ const DatasetCard = ({ const onConfirmDelete = useCallback(async () => { try { await deleteDataset(dataset.id) + + // Clear SWR cache to prevent stale data in knowledge retrieval nodes + mutate( + (key) => { + if (typeof key === 'string') return key.includes('/datasets') + if (typeof key === 'object' && key !== null) + return key.url === '/datasets' || key.url?.includes('/datasets') + return false + }, + undefined, + { revalidate: true }, + ) + notify({ type: 'success', message: t('dataset.datasetDeleted') }) if (onSuccess) onSuccess() @@ -162,24 +176,19 @@ const DatasetCard = ({
{dataset.description}
-
+
{ e.stopPropagation() e.preventDefault() }}>
+ containerRef: React.RefObject tags: string[] keywords: string includeAll: boolean diff --git a/web/app/(commonLayout)/datasets/doc.tsx b/web/app/(commonLayout)/datasets/doc.tsx new file mode 100644 index 0000000000..b31e0a4161 --- /dev/null +++ b/web/app/(commonLayout)/datasets/doc.tsx @@ -0,0 +1,203 @@ +'use client' + +import { useEffect, useMemo, useState } from 'react' +import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import { RiCloseLine, RiListUnordered } from '@remixicon/react' +import TemplateEn from './template/template.en.mdx' +import TemplateZh from './template/template.zh.mdx' +import TemplateJa from './template/template.ja.mdx' +import I18n from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import cn from '@/utils/classnames' + +type DocProps = { + apiBaseUrl: string +} + +const Doc = ({ apiBaseUrl }: DocProps) => { + const { locale } = useContext(I18n) + const { t } = useTranslation() + const [toc, setToc] = useState>([]) + const [isTocExpanded, setIsTocExpanded] = useState(false) + const [activeSection, setActiveSection] = useState('') + const { theme } = useTheme() + + // Set initial TOC expanded state based on screen width + useEffect(() => { + const mediaQuery = window.matchMedia('(min-width: 1280px)') + setIsTocExpanded(mediaQuery.matches) + }, []) + + // Extract TOC from article content + useEffect(() => { + const extractTOC = () => { + const article = document.querySelector('article') + if (article) { + const headings = article.querySelectorAll('h2') + const tocItems = Array.from(headings).map((heading) => { + const anchor = heading.querySelector('a') + if (anchor) { + return { + href: anchor.getAttribute('href') || '', + text: anchor.textContent || '', + } + } + return null + }).filter((item): item is { href: string; text: string } => item !== null) + setToc(tocItems) + // Set initial active section + if (tocItems.length > 0) + setActiveSection(tocItems[0].href.replace('#', '')) + } + } + + setTimeout(extractTOC, 0) + }, [locale]) + + // Track scroll position for active section highlighting + useEffect(() => { + const handleScroll = () => { + const scrollContainer = document.querySelector('.scroll-container') + if (!scrollContainer || toc.length === 0) + return + + // Find active section based on scroll position + let currentSection = '' + toc.forEach((item) => { + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const rect = element.getBoundingClientRect() + // Consider section active if its top is above the middle of viewport + if (rect.top <= window.innerHeight / 2) + currentSection = targetId + } + }) + + if (currentSection && currentSection !== activeSection) + setActiveSection(currentSection) + } + + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + scrollContainer.addEventListener('scroll', handleScroll) + handleScroll() // Initial check + return () => scrollContainer.removeEventListener('scroll', handleScroll) + } + }, [toc, activeSection]) + + // Handle TOC item click + const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { + e.preventDefault() + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + const 
headerOffset = -40 + const elementTop = element.offsetTop - headerOffset + scrollContainer.scrollTo({ + top: elementTop, + behavior: 'smooth', + }) + } + } + } + + const Template = useMemo(() => { + switch (locale) { + case LanguagesSupported[1]: + return + case LanguagesSupported[7]: + return + default: + return + } + }, [apiBaseUrl, locale]) + + return ( +
+
+ {isTocExpanded + ? ( + + ) + : ( + + )} +
+
+ {Template} +
+
+ ) +} + +export default Doc diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/new-dataset-card.tsx similarity index 100% rename from web/app/(commonLayout)/datasets/NewDatasetCard.tsx rename to web/app/(commonLayout)/datasets/new-dataset-card.tsx diff --git a/web/app/(commonLayout)/datasets/page.tsx b/web/app/(commonLayout)/datasets/page.tsx index 60a542f0a2..cbfe25ebd2 100644 --- a/web/app/(commonLayout)/datasets/page.tsx +++ b/web/app/(commonLayout)/datasets/page.tsx @@ -1,6 +1,6 @@ 'use client' import { useTranslation } from 'react-i18next' -import Container from './Container' +import Container from './container' import useDocumentTitle from '@/hooks/use-document-title' const AppList = () => { diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index ebb2e6a806..f1bb5d9156 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
+
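The new `doc.tsx` above derives its table of contents at runtime: it queries the rendered `<article>` for `h2` anchors, stores them as `{ href, text }` items, and highlights whichever section's heading has scrolled past the vertical middle of the viewport. A minimal standalone sketch of that scrollspy logic as a hook (the `.scroll-container` selector and the item shape follow the component above; the hook name `useScrollSpy` is ours, not part of this diff):

```ts
import { useEffect, useState } from 'react'

type TocItem = { href: string; text: string }

// Returns the id of the section whose heading most recently crossed the
// vertical middle of the viewport while scrolling inside `containerSelector`.
export function useScrollSpy(toc: TocItem[], containerSelector = '.scroll-container') {
  const [activeSection, setActiveSection] = useState('')

  useEffect(() => {
    const container = document.querySelector(containerSelector)
    if (!container || toc.length === 0)
      return

    const handleScroll = () => {
      let current = ''
      toc.forEach(({ href }) => {
        const id = href.replace('#', '')
        const el = document.getElementById(id)
        // A section counts as active once its heading passes the viewport middle.
        if (el && el.getBoundingClientRect().top <= window.innerHeight / 2)
          current = id
      })
      if (current)
        setActiveSection(current)
    }

    container.addEventListener('scroll', handleScroll)
    handleScroll() // initial check, as in the component above
    return () => container.removeEventListener('scroll', handleScroll)
  }, [toc, containerSelector])

  return activeSection
}
```

Depending only on `toc` (rather than also on the current highlight) avoids re-attaching the scroll listener on every section change; the inline version above lists `activeSection` in its dependency array instead and re-binds the listener whenever the highlight moves.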
diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx index 6c0e20e1bb..3011cecbc1 100644 --- a/web/app/(commonLayout)/datasets/template/template.ja.mdx +++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
+
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index c21ce3bf5f..b7ea889a46 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
@@ -1915,7 +1915,7 @@
diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index da754794b1..91e1021610 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -70,7 +70,10 @@ export default function CheckCode() {

{t('login.checkCode.checkYourEmail')}

- + + {t('login.checkCode.tipsPrefix')} + {email} +
{t('login.checkCode.validTime')}

diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index a2ba620ace..c80a006583 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -93,7 +93,10 @@ export default function CheckCode() {

{t('login.checkCode.checkYourEmail')}

- + + {t('login.checkCode.tipsPrefix')} + {email} +
{t('login.checkCode.validTime')}
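Both check-code pages above make the same change: rather than one translated sentence with the address baked in, they render the translated prefix (`login.checkCode.tipsPrefix`), the raw `email`, and the validity note as separate nodes. The wrapper markup did not survive in the hunks above, so the sketch below is illustrative only; the element choice and styling are assumptions, not the actual diff:

```tsx
import { useTranslation } from 'react-i18next'

// Hypothetical extraction of the shared snippet; the real pages inline this JSX.
const CheckCodeTips = ({ email }: { email: string }) => {
  const { t } = useTranslation()
  return (
    <p>
      {t('login.checkCode.tipsPrefix')}
      {/* The address is its own node, so it is never interpolated into the translated string. */}
      <strong>{email}</strong>
      <span className='block'>{t('login.checkCode.validTime')}</span>
    </p>
  )
}

export default CheckCodeTips
```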

diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 8250789def..41a6971bf5 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -87,7 +87,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
{ setIsShowAvatarPicker(true) }} - className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black bg-opacity-50 opacity-0 transition-opacity group-hover:opacity-100" + className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100" > diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 58c9f7e5ca..288dcf8c8b 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -12,7 +12,6 @@ import { RiFileUploadLine, } from '@remixicon/react' import AppIcon from '../base/app-icon' -import cn from '@/utils/classnames' import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' @@ -31,6 +30,7 @@ import Divider from '../base/divider' import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' +import cn from '@/utils/classnames' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, @@ -256,31 +256,40 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx }} className='block w-full' > -
-
- -
-
+
+
+
+ +
+ {expand && ( +
+
+ +
+
+ )} +
+ {!expand && ( +
+
-
- { - expand && ( -
-
-
{appDetail.name}
-
-
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
+ )} + {expand && ( +
+
+
{appDetail.name}
- ) - } +
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
+
+ )}
)} diff --git a/web/app/components/app-sidebar/dataset-info.tsx b/web/app/components/app-sidebar/dataset-info.tsx index 73740133ce..eee7cb3a2e 100644 --- a/web/app/components/app-sidebar/dataset-info.tsx +++ b/web/app/components/app-sidebar/dataset-info.tsx @@ -29,15 +29,17 @@ const DatasetInfo: FC = ({
- {expand && ( -
-
- {name} -
-
{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}
-
{description}
+
+
+ {name}
- )} +
{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}
+
{description}
+
{extraInfo}
) diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index b6bfc0e9ac..cf32339b8a 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -124,10 +124,7 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati { !isMobile && (
({ + useSelectedLayoutSegment: () => 'overview', +})) + +// Mock Next.js Link component +jest.mock('next/link', () => { + return function MockLink({ children, href, className, title }: any) { + return ( + + {children} + + ) + } +}) + +// Mock RemixIcon components +const MockIcon = ({ className }: { className?: string }) => ( + +) + +describe('NavLink Text Animation Issues', () => { + const mockProps: NavLinkProps = { + name: 'Orchestrate', + href: '/app/123/workflow', + iconMap: { + selected: MockIcon, + normal: MockIcon, + }, + } + + beforeEach(() => { + // Mock getComputedStyle for transition testing + Object.defineProperty(window, 'getComputedStyle', { + value: jest.fn((element) => { + const isExpanded = element.getAttribute('data-mode') === 'expand' + return { + transition: 'all 0.3s ease', + opacity: isExpanded ? '1' : '0', + width: isExpanded ? 'auto' : '0px', + overflow: 'hidden', + paddingLeft: isExpanded ? '12px' : '10px', // px-3 vs px-2.5 + paddingRight: isExpanded ? '12px' : '10px', + } + }), + writable: true, + }) + }) + + describe('Text Squeeze Animation Issue', () => { + it('should show text squeeze effect when switching from collapse to expand', async () => { + const { rerender } = render() + + // In collapse mode, text should be in DOM but hidden via CSS + const textElement = screen.getByText('Orchestrate') + expect(textElement).toBeInTheDocument() + expect(textElement).toHaveClass('opacity-0') + expect(textElement).toHaveClass('w-0') + expect(textElement).toHaveClass('overflow-hidden') + + // Icon should still be present + expect(screen.getByTestId('nav-icon')).toBeInTheDocument() + + // Check padding in collapse mode + const linkElement = screen.getByTestId('nav-link') + expect(linkElement).toHaveClass('px-2.5') + + // Switch to expand mode - this is where the squeeze effect occurs + rerender() + + // Text should now appear + expect(screen.getByText('Orchestrate')).toBeInTheDocument() + + // Check padding change - this contributes to the squeeze effect + expect(linkElement).toHaveClass('px-3') + + // The bug: text appears abruptly without smooth transition + // This test documents the current behavior that causes the squeeze effect + const expandedTextElement = screen.getByText('Orchestrate') + expect(expandedTextElement).toBeInTheDocument() + + // In a properly animated version, we would expect: + // - Opacity transition from 0 to 1 + // - Width transition from 0 to auto + // - No layout shift from padding changes + }) + + it('should maintain icon position consistency during text appearance', () => { + const { rerender } = render() + + const iconElement = screen.getByTestId('nav-icon') + const initialIconClasses = iconElement.className + + // Icon should have mr-0 in collapse mode + expect(iconElement).toHaveClass('mr-0') + + rerender() + + const expandedIconClasses = iconElement.className + + // Icon should have mr-2 in expand mode - this shift contributes to the squeeze effect + expect(iconElement).toHaveClass('mr-2') + + console.log('Collapsed icon classes:', initialIconClasses) + console.log('Expanded icon classes:', expandedIconClasses) + + // This margin change causes the icon to shift when text appears + }) + + it('should document the abrupt text rendering issue', () => { + const { rerender } = render() + + // Text is present in DOM but hidden via CSS classes + const collapsedText = screen.getByText('Orchestrate') + expect(collapsedText).toBeInTheDocument() + expect(collapsedText).toHaveClass('opacity-0') + expect(collapsedText).toHaveClass('pointer-events-none') 
+ + rerender() + + // Text suddenly appears in DOM - no transition + expect(screen.getByText('Orchestrate')).toBeInTheDocument() + + // The issue: {mode === 'expand' && name} causes abrupt show/hide + // instead of smooth opacity/width transition + }) + }) + + describe('Layout Shift Issues', () => { + it('should detect padding differences causing layout shifts', () => { + const { rerender } = render() + + const linkElement = screen.getByTestId('nav-link') + + // Collapsed state padding + expect(linkElement).toHaveClass('px-2.5') + + rerender() + + // Expanded state padding - different value causes layout shift + expect(linkElement).toHaveClass('px-3') + + // This 2px difference (10px vs 12px) contributes to the squeeze effect + }) + + it('should detect icon margin changes causing shifts', () => { + const { rerender } = render() + + const iconElement = screen.getByTestId('nav-icon') + + // Collapsed: no right margin + expect(iconElement).toHaveClass('mr-0') + + rerender() + + // Expanded: 8px right margin (mr-2) + expect(iconElement).toHaveClass('mr-2') + + // This sudden margin appearance causes the squeeze effect + }) + }) + + describe('Active State Handling', () => { + it('should handle active state correctly in both modes', () => { + // Test non-active state + const { rerender } = render() + + let linkElement = screen.getByTestId('nav-link') + expect(linkElement).not.toHaveClass('bg-state-accent-active') + + // Test with active state (when href matches current segment) + const activeProps = { + ...mockProps, + href: '/app/123/overview', // matches mocked segment + } + + rerender() + + linkElement = screen.getByTestId('nav-link') + expect(linkElement).toHaveClass('bg-state-accent-active') + }) + }) +}) diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index 295b553b04..4607f7b693 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -44,20 +44,29 @@ export default function NavLink({ key={name} href={href} className={classNames( - isActive ? 'bg-state-accent-active text-text-accent font-semibold' : 'text-components-menu-item-text hover:bg-state-base-hover hover:text-components-menu-item-text-hover', - 'group flex items-center h-9 rounded-md py-2 text-sm font-normal', + isActive ? 'bg-state-accent-active font-semibold text-text-accent' : 'text-components-menu-item-text hover:bg-state-base-hover hover:text-components-menu-item-text-hover', + 'group flex h-9 items-center rounded-md py-2 text-sm font-normal', mode === 'expand' ? 'px-3' : 'px-2.5', )} title={mode === 'collapse' ? name : ''} >
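The test above pins down why the sidebar label used to 'squeeze' into view: the text was conditionally rendered (`{mode === 'expand' && name}`) while the link padding (`px-2.5` to `px-3`) and the icon margin (`mr-0` to `mr-2`) jumped in the same frame. The reworked link instead keeps the label mounted and hides it with `opacity-0 w-0 overflow-hidden pointer-events-none`, letting CSS transitions do the work. The `navLink.tsx` diff is cut off before the label markup, so the sketch below only illustrates that approach; the class names follow the test's expectations, while the `duration-300` value and the component shape are assumptions:

```tsx
import type { ElementType } from 'react'
import cn from '@/utils/classnames'

type Props = {
  name: string
  mode: 'expand' | 'collapse'
  icon: ElementType
}

// The label stays in the DOM in both modes; only its width and opacity animate,
// so expanding the sidebar never inserts a new node mid-layout.
const NavLinkLabel = ({ name, mode, icon: Icon }: Props) => {
  const expanded = mode === 'expand'
  return (
    <>
      <Icon className={cn('h-4 w-4 shrink-0 transition-all duration-300', expanded ? 'mr-2' : 'mr-0')} />
      <span
        className={cn(
          'whitespace-nowrap transition-all duration-300',
          expanded ? 'opacity-100' : 'pointer-events-none w-0 overflow-hidden opacity-0',
        )}
      >
        {name}
      </span>
    </>
  )
}

export default NavLinkLabel
```

Keeping the link's horizontal padding constant across modes would also remove the 2 px layout shift the test comments call out.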