diff --git a/.env.example b/.env.example deleted file mode 100644 index 3e95f2e982..0000000000 --- a/.env.example +++ /dev/null @@ -1,1197 +0,0 @@ -# ------------------------------ -# Environment Variables for API service & worker -# ------------------------------ - -# ------------------------------ -# Common Variables -# ------------------------------ - -# The backend URL of the console API, -# used to concatenate the authorization callback. -# If empty, it is the same domain. -# Example: https://api.console.dify.ai -CONSOLE_API_URL= - -# The front-end URL of the console web, -# used to concatenate some front-end addresses and for CORS configuration use. -# If empty, it is the same domain. -# Example: https://console.dify.ai -CONSOLE_WEB_URL= - -# Service API Url, -# used to display Service API Base Url to the front-end. -# If empty, it is the same domain. -# Example: https://api.dify.ai -SERVICE_API_URL= - -# WebApp API backend Url, -# used to declare the back-end URL for the front-end API. -# If empty, it is the same domain. -# Example: https://api.app.dify.ai -APP_API_URL= - -# WebApp Url, -# used to display WebAPP API Base Url to the front-end. -# If empty, it is the same domain. -# Example: https://app.dify.ai -APP_WEB_URL= - -# File preview or download Url prefix. -# used to display File preview or download Url to the front-end or as Multi-model inputs; -# Url is signed and has expiration time. -# Setting FILES_URL is required for file processing plugins. -# - For https://example.com, use FILES_URL=https://example.com -# - For http://example.com, use FILES_URL=http://example.com -# Recommendation: use a dedicated domain (e.g., https://upload.example.com). -# Alternatively, use http://:5001 or http://api:5001, -# ensuring port 5001 is externally accessible (see docker-compose.yaml). -FILES_URL= - -# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. -# Set this to the internal Docker service URL for proper plugin file access. -# Example: INTERNAL_FILES_URL=http://api:5001 -INTERNAL_FILES_URL= - -# ------------------------------ -# Server Configuration -# ------------------------------ - -# The log level for the application. -# Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` -LOG_LEVEL=INFO -# Log file path -LOG_FILE=/app/logs/server.log -# Log file max size, the unit is MB -LOG_FILE_MAX_SIZE=20 -# Log file max backup count -LOG_FILE_BACKUP_COUNT=5 -# Log dateformat -LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S -# Log Timezone -LOG_TZ=UTC - -# Debug mode, default is false. -# It is recommended to turn on this configuration for local development -# to prevent some problems caused by monkey patch. -DEBUG=false - -# Flask debug mode, it can output trace information at the interface when turned on, -# which is convenient for debugging. -FLASK_DEBUG=false - -# Enable request logging, which will log the request and response information. -# And the log level is DEBUG -ENABLE_REQUEST_LOGGING=False - -# A secret key that is used for securely signing the session cookie -# and encrypting sensitive information on the database. -# You can generate a strong key using `openssl rand -base64 42`. -SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U - -# Password for admin user initialization. -# If left unset, admin user will not be prompted for a password -# when creating the initial admin account. -# The length of the password cannot exceed 30 characters. -INIT_PASSWORD= - -# Deployment environment. -# Supported values are `PRODUCTION`, `TESTING`. 
Default is `PRODUCTION`. -# Testing environment. There will be a distinct color label on the front-end page, -# indicating that this environment is a testing environment. -DEPLOY_ENV=PRODUCTION - -# Whether to enable the version check policy. -# If set to empty, https://updates.dify.ai will be called for version check. -CHECK_UPDATE_URL=https://updates.dify.ai - -# Used to change the OpenAI base address, default is https://api.openai.com/v1. -# When OpenAI cannot be accessed in China, replace it with a domestic mirror address, -# or when a local model provides OpenAI compatible API, it can be replaced. -OPENAI_API_BASE=https://api.openai.com/v1 - -# When enabled, migrations will be executed prior to application startup -# and the application will start after the migrations have completed. -MIGRATION_ENABLED=true - -# File Access Time specifies a time interval in seconds for the file to be accessed. -# The default value is 300 seconds. -FILES_ACCESS_TIMEOUT=300 - -# Access token expiration time in minutes -ACCESS_TOKEN_EXPIRE_MINUTES=60 - -# Refresh token expiration time in days -REFRESH_TOKEN_EXPIRE_DAYS=30 - -# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. -APP_MAX_ACTIVE_REQUESTS=0 -APP_MAX_EXECUTION_TIME=1200 - -# ------------------------------ -# Container Startup Related Configuration -# Only effective when starting with docker image or docker-compose. -# ------------------------------ - -# API service binding address, default: 0.0.0.0, i.e., all addresses can be accessed. -DIFY_BIND_ADDRESS=0.0.0.0 - -# API service binding port number, default 5001. -DIFY_PORT=5001 - -# The number of API server workers, i.e., the number of workers. -# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent -# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers -SERVER_WORKER_AMOUNT=1 - -# Defaults to gevent. If using windows, it can be switched to sync or solo. -SERVER_WORKER_CLASS=gevent - -# Default number of worker connections, the default is 10. -SERVER_WORKER_CONNECTIONS=10 - -# Similar to SERVER_WORKER_CLASS. -# If using windows, it can be switched to sync or solo. -CELERY_WORKER_CLASS= - -# Request handling timeout. The default is 200, -# it is recommended to set it to 360 to support a longer sse connection time. -GUNICORN_TIMEOUT=360 - -# The number of Celery workers. The default is 1, and can be set as needed. -CELERY_WORKER_AMOUNT= - -# Flag indicating whether to enable autoscaling of Celery workers. -# -# Autoscaling is useful when tasks are CPU intensive and can be dynamically -# allocated and deallocated based on the workload. -# -# When autoscaling is enabled, the maximum and minimum number of workers can -# be specified. The autoscaling algorithm will dynamically adjust the number -# of workers within the specified range. -# -# Default is false (i.e., autoscaling is disabled). -# -# Example: -# CELERY_AUTO_SCALE=true -CELERY_AUTO_SCALE=false - -# The maximum number of Celery workers that can be autoscaled. -# This is optional and only used when autoscaling is enabled. -# Default is not set. -CELERY_MAX_WORKERS= - -# The minimum number of Celery workers that can be autoscaled. -# This is optional and only used when autoscaling is enabled. -# Default is not set. 
-CELERY_MIN_WORKERS= - -# API Tool configuration -API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 -API_TOOL_DEFAULT_READ_TIMEOUT=60 - -# ------------------------------- -# Datasource Configuration -# -------------------------------- -ENABLE_WEBSITE_JINAREADER=true -ENABLE_WEBSITE_FIRECRAWL=true -ENABLE_WEBSITE_WATERCRAWL=true - -# ------------------------------ -# Database Configuration -# The database uses PostgreSQL. Please use the public schema. -# It is consistent with the configuration in the 'db' service below. -# ------------------------------ - -DB_USERNAME=postgres -DB_PASSWORD=difyai123456 -DB_HOST=db -DB_PORT=5432 -DB_DATABASE=dify -# The size of the database connection pool. -# The default is 30 connections, which can be appropriately increased. -SQLALCHEMY_POOL_SIZE=30 -# Database connection pool recycling time, the default is 3600 seconds. -SQLALCHEMY_POOL_RECYCLE=3600 -# Whether to print SQL, default is false. -SQLALCHEMY_ECHO=false -# If True, will test connections for liveness upon each checkout -SQLALCHEMY_POOL_PRE_PING=false -# Whether to enable the Last in first out option or use default FIFO queue if is false -SQLALCHEMY_POOL_USE_LIFO=false - -# Maximum number of connections to the database -# Default is 100 -# -# Reference: https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-MAX-CONNECTIONS -POSTGRES_MAX_CONNECTIONS=100 - -# Sets the amount of shared memory used for postgres's shared buffers. -# Default is 128MB -# Recommended value: 25% of available memory -# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-SHARED-BUFFERS -POSTGRES_SHARED_BUFFERS=128MB - -# Sets the amount of memory used by each database worker for working space. -# Default is 4MB -# -# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-WORK-MEM -POSTGRES_WORK_MEM=4MB - -# Sets the amount of memory reserved for maintenance activities. -# Default is 64MB -# -# Reference: https://www.postgresql.org/docs/current/runtime-config-resource.html#GUC-MAINTENANCE-WORK-MEM -POSTGRES_MAINTENANCE_WORK_MEM=64MB - -# Sets the planner's assumption about the effective cache size. -# Default is 4096MB -# -# Reference: https://www.postgresql.org/docs/current/runtime-config-query.html#GUC-EFFECTIVE-CACHE-SIZE -POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB - -# ------------------------------ -# Redis Configuration -# This Redis configuration is used for caching and for pub/sub during conversation. -# ------------------------------ - -REDIS_HOST=redis -REDIS_PORT=6379 -REDIS_USERNAME= -REDIS_PASSWORD=difyai123456 -REDIS_USE_SSL=false -REDIS_DB=0 - -# Whether to use Redis Sentinel mode. -# If set to true, the application will automatically discover and connect to the master node through Sentinel. -REDIS_USE_SENTINEL=false - -# List of Redis Sentinel nodes. If Sentinel mode is enabled, provide at least one Sentinel IP and port. -# Format: `:,:,:` -REDIS_SENTINELS= -REDIS_SENTINEL_SERVICE_NAME= -REDIS_SENTINEL_USERNAME= -REDIS_SENTINEL_PASSWORD= -REDIS_SENTINEL_SOCKET_TIMEOUT=0.1 - -# List of Redis Cluster nodes. If Cluster mode is enabled, provide at least one Cluster IP and port. -# Format: `:,:,:` -REDIS_USE_CLUSTERS=false -REDIS_CLUSTERS= -REDIS_CLUSTERS_PASSWORD= - -# ------------------------------ -# Celery Configuration -# ------------------------------ - -# Use redis as the broker, and redis db 1 for celery broker. 
-# Format as follows: `redis://:@:/` -# Example: redis://:difyai123456@redis:6379/1 -# If use Redis Sentinel, format as follows: `sentinel://:@:/` -# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1 -CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1 -BROKER_USE_SSL=false - -# If you are using Redis Sentinel for high availability, configure the following settings. -CELERY_USE_SENTINEL=false -CELERY_SENTINEL_MASTER_NAME= -CELERY_SENTINEL_PASSWORD= -CELERY_SENTINEL_SOCKET_TIMEOUT=0.1 - -# ------------------------------ -# CORS Configuration -# Used to set the front-end cross-domain access policy. -# ------------------------------ - -# Specifies the allowed origins for cross-origin requests to the Web API, -# e.g. https://dify.app or * for all origins. -WEB_API_CORS_ALLOW_ORIGINS=* - -# Specifies the allowed origins for cross-origin requests to the console API, -# e.g. https://cloud.dify.ai or * for all origins. -CONSOLE_CORS_ALLOW_ORIGINS=* - -# ------------------------------ -# File Storage Configuration -# ------------------------------ - -# The type of storage to use for storing user files. -STORAGE_TYPE=opendal - -# Apache OpenDAL Configuration -# The configuration for OpenDAL consists of the following format: OPENDAL__. -# You can find all the service configurations (CONFIG_NAME) in the repository at: https://github.com/apache/opendal/tree/main/core/src/services. -# Dify will scan configurations starting with OPENDAL_ and automatically apply them. -# The scheme name for the OpenDAL storage. -OPENDAL_SCHEME=fs -# Configurations for OpenDAL Local File System. -OPENDAL_FS_ROOT=storage - -# ClickZetta Volume Configuration (for storage backend) -# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume -# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters - -# Volume type selection (three types available): -# - user: Personal/small team use, simple config, user-level permissions -# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions -# - external: Data lake integration, external storage connection, volume-level + storage-level permissions -CLICKZETTA_VOLUME_TYPE=user - -# External Volume name (required only when TYPE=external) -CLICKZETTA_VOLUME_NAME= - -# Table Volume table prefix (used only when TYPE=table) -CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_ - -# Dify file directory prefix (isolates from other apps, recommended to keep default) -CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km - -# S3 Configuration -# -S3_ENDPOINT= -S3_REGION=us-east-1 -S3_BUCKET_NAME=difyai -S3_ACCESS_KEY= -S3_SECRET_KEY= -# Whether to use AWS managed IAM roles for authenticating with the S3 service. -# If set to false, the access key and secret key must be provided. -S3_USE_AWS_MANAGED_IAM=false - -# Azure Blob Configuration -# -AZURE_BLOB_ACCOUNT_NAME=difyai -AZURE_BLOB_ACCOUNT_KEY=difyai -AZURE_BLOB_CONTAINER_NAME=difyai-container -AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net - -# Google Storage Configuration -# -GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name -GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64= - -# The Alibaba Cloud OSS configurations, -# -ALIYUN_OSS_BUCKET_NAME=your-bucket-name -ALIYUN_OSS_ACCESS_KEY=your-access-key -ALIYUN_OSS_SECRET_KEY=your-secret-key -ALIYUN_OSS_ENDPOINT=https://oss-ap-southeast-1-internal.aliyuncs.com -ALIYUN_OSS_REGION=ap-southeast-1 -ALIYUN_OSS_AUTH_VERSION=v4 -# Don't start with '/'. OSS doesn't support leading slash in object names. 
-ALIYUN_OSS_PATH=your-path - -# Tencent COS Configuration -# -TENCENT_COS_BUCKET_NAME=your-bucket-name -TENCENT_COS_SECRET_KEY=your-secret-key -TENCENT_COS_SECRET_ID=your-secret-id -TENCENT_COS_REGION=your-region -TENCENT_COS_SCHEME=your-scheme - -# Oracle Storage Configuration -# -OCI_ENDPOINT=https://your-object-storage-namespace.compat.objectstorage.us-ashburn-1.oraclecloud.com -OCI_BUCKET_NAME=your-bucket-name -OCI_ACCESS_KEY=your-access-key -OCI_SECRET_KEY=your-secret-key -OCI_REGION=us-ashburn-1 - -# Huawei OBS Configuration -# -HUAWEI_OBS_BUCKET_NAME=your-bucket-name -HUAWEI_OBS_SECRET_KEY=your-secret-key -HUAWEI_OBS_ACCESS_KEY=your-access-key -HUAWEI_OBS_SERVER=your-server-url - -# Volcengine TOS Configuration -# -VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name -VOLCENGINE_TOS_SECRET_KEY=your-secret-key -VOLCENGINE_TOS_ACCESS_KEY=your-access-key -VOLCENGINE_TOS_ENDPOINT=your-server-url -VOLCENGINE_TOS_REGION=your-region - -# Baidu OBS Storage Configuration -# -BAIDU_OBS_BUCKET_NAME=your-bucket-name -BAIDU_OBS_SECRET_KEY=your-secret-key -BAIDU_OBS_ACCESS_KEY=your-access-key -BAIDU_OBS_ENDPOINT=your-server-url - -# Supabase Storage Configuration -# -SUPABASE_BUCKET_NAME=your-bucket-name -SUPABASE_API_KEY=your-access-key -SUPABASE_URL=your-server-url - -# ------------------------------ -# Vector Database Configuration -# ------------------------------ - -# The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. -VECTOR_STORE=weaviate - -# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. -WEAVIATE_ENDPOINT=http://weaviate:8080 -WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih - -# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`. -QDRANT_URL=http://qdrant:6333 -QDRANT_API_KEY=difyai123456 -QDRANT_CLIENT_TIMEOUT=20 -QDRANT_GRPC_ENABLED=false -QDRANT_GRPC_PORT=6334 -QDRANT_REPLICATION_FACTOR=1 - -# Milvus configuration. Only available when VECTOR_STORE is `milvus`. -# The milvus uri. 
-MILVUS_URI=http://host.docker.internal:19530 -MILVUS_DATABASE= -MILVUS_TOKEN= -MILVUS_USER= -MILVUS_PASSWORD= -MILVUS_ENABLE_HYBRID_SEARCH=False -MILVUS_ANALYZER_PARAMS= - -# MyScale configuration, only available when VECTOR_STORE is `myscale` -# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: -# https://myscale.com/docs/en/text-search/#understanding-fts-index-parameters -MYSCALE_HOST=myscale -MYSCALE_PORT=8123 -MYSCALE_USER=default -MYSCALE_PASSWORD= -MYSCALE_DATABASE=dify -MYSCALE_FTS_PARAMS= - -# Couchbase configurations, only available when VECTOR_STORE is `couchbase` -# The connection string must include hostname defined in the docker-compose file (couchbase-server in this case) -COUCHBASE_CONNECTION_STRING=couchbase://couchbase-server -COUCHBASE_USER=Administrator -COUCHBASE_PASSWORD=password -COUCHBASE_BUCKET_NAME=Embeddings -COUCHBASE_SCOPE_NAME=_default - -# pgvector configurations, only available when VECTOR_STORE is `pgvector` -PGVECTOR_HOST=pgvector -PGVECTOR_PORT=5432 -PGVECTOR_USER=postgres -PGVECTOR_PASSWORD=difyai123456 -PGVECTOR_DATABASE=dify -PGVECTOR_MIN_CONNECTION=1 -PGVECTOR_MAX_CONNECTION=5 -PGVECTOR_PG_BIGM=false -PGVECTOR_PG_BIGM_VERSION=1.2-20240606 - -# vastbase configurations, only available when VECTOR_STORE is `vastbase` -VASTBASE_HOST=vastbase -VASTBASE_PORT=5432 -VASTBASE_USER=dify -VASTBASE_PASSWORD=Difyai123456 -VASTBASE_DATABASE=dify -VASTBASE_MIN_CONNECTION=1 -VASTBASE_MAX_CONNECTION=5 - -# pgvecto-rs configurations, only available when VECTOR_STORE is `pgvecto-rs` -PGVECTO_RS_HOST=pgvecto-rs -PGVECTO_RS_PORT=5432 -PGVECTO_RS_USER=postgres -PGVECTO_RS_PASSWORD=difyai123456 -PGVECTO_RS_DATABASE=dify - -# analyticdb configurations, only available when VECTOR_STORE is `analyticdb` -ANALYTICDB_KEY_ID=your-ak -ANALYTICDB_KEY_SECRET=your-sk -ANALYTICDB_REGION_ID=cn-hangzhou -ANALYTICDB_INSTANCE_ID=gp-ab123456 -ANALYTICDB_ACCOUNT=testaccount -ANALYTICDB_PASSWORD=testpassword -ANALYTICDB_NAMESPACE=dify -ANALYTICDB_NAMESPACE_PASSWORD=difypassword -ANALYTICDB_HOST=gp-test.aliyuncs.com -ANALYTICDB_PORT=5432 -ANALYTICDB_MIN_CONNECTION=1 -ANALYTICDB_MAX_CONNECTION=5 - -# TiDB vector configurations, only available when VECTOR_STORE is `tidb_vector` -TIDB_VECTOR_HOST=tidb -TIDB_VECTOR_PORT=4000 -TIDB_VECTOR_USER= -TIDB_VECTOR_PASSWORD= -TIDB_VECTOR_DATABASE=dify - -# Matrixone vector configurations. 
-MATRIXONE_HOST=matrixone -MATRIXONE_PORT=6001 -MATRIXONE_USER=dump -MATRIXONE_PASSWORD=111 -MATRIXONE_DATABASE=dify - -# Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant` -TIDB_ON_QDRANT_URL=http://127.0.0.1 -TIDB_ON_QDRANT_API_KEY=dify -TIDB_ON_QDRANT_CLIENT_TIMEOUT=20 -TIDB_ON_QDRANT_GRPC_ENABLED=false -TIDB_ON_QDRANT_GRPC_PORT=6334 -TIDB_PUBLIC_KEY=dify -TIDB_PRIVATE_KEY=dify -TIDB_API_URL=http://127.0.0.1 -TIDB_IAM_API_URL=http://127.0.0.1 -TIDB_REGION=regions/aws-us-east-1 -TIDB_PROJECT_ID=dify -TIDB_SPEND_LIMIT=100 - -# Chroma configuration, only available when VECTOR_STORE is `chroma` -CHROMA_HOST=127.0.0.1 -CHROMA_PORT=8000 -CHROMA_TENANT=default_tenant -CHROMA_DATABASE=default_database -CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthClientProvider -CHROMA_AUTH_CREDENTIALS= - -# Oracle configuration, only available when VECTOR_STORE is `oracle` -ORACLE_USER=dify -ORACLE_PASSWORD=dify -ORACLE_DSN=oracle:1521/FREEPDB1 -ORACLE_CONFIG_DIR=/app/api/storage/wallet -ORACLE_WALLET_LOCATION=/app/api/storage/wallet -ORACLE_WALLET_PASSWORD=dify -ORACLE_IS_AUTONOMOUS=false - -# relyt configurations, only available when VECTOR_STORE is `relyt` -RELYT_HOST=db -RELYT_PORT=5432 -RELYT_USER=postgres -RELYT_PASSWORD=difyai123456 -RELYT_DATABASE=postgres - -# open search configuration, only available when VECTOR_STORE is `opensearch` -OPENSEARCH_HOST=opensearch -OPENSEARCH_PORT=9200 -OPENSEARCH_SECURE=true -OPENSEARCH_VERIFY_CERTS=true -OPENSEARCH_AUTH_METHOD=basic -OPENSEARCH_USER=admin -OPENSEARCH_PASSWORD=admin -# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless -OPENSEARCH_AWS_REGION=ap-southeast-1 -OPENSEARCH_AWS_SERVICE=aoss - -# tencent vector configurations, only available when VECTOR_STORE is `tencent` -TENCENT_VECTOR_DB_URL=http://127.0.0.1 -TENCENT_VECTOR_DB_API_KEY=dify -TENCENT_VECTOR_DB_TIMEOUT=30 -TENCENT_VECTOR_DB_USERNAME=dify -TENCENT_VECTOR_DB_DATABASE=dify -TENCENT_VECTOR_DB_SHARD=1 -TENCENT_VECTOR_DB_REPLICAS=2 -TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false - -# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` -ELASTICSEARCH_HOST=0.0.0.0 -ELASTICSEARCH_PORT=9200 -ELASTICSEARCH_USERNAME=elastic -ELASTICSEARCH_PASSWORD=elastic -KIBANA_PORT=5601 - -# baidu vector configurations, only available when VECTOR_STORE is `baidu` -BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 -BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 -BAIDU_VECTOR_DB_ACCOUNT=root -BAIDU_VECTOR_DB_API_KEY=dify -BAIDU_VECTOR_DB_DATABASE=dify -BAIDU_VECTOR_DB_SHARD=1 -BAIDU_VECTOR_DB_REPLICAS=3 - -# VikingDB configurations, only available when VECTOR_STORE is `vikingdb` -VIKINGDB_ACCESS_KEY=your-ak -VIKINGDB_SECRET_KEY=your-sk -VIKINGDB_REGION=cn-shanghai -VIKINGDB_HOST=api-vikingdb.xxx.volces.com -VIKINGDB_SCHEMA=http -VIKINGDB_CONNECTION_TIMEOUT=30 -VIKINGDB_SOCKET_TIMEOUT=30 - -# Lindorm configuration, only available when VECTOR_STORE is `lindorm` -LINDORM_URL=http://lindorm:30070 -LINDORM_USERNAME=lindorm -LINDORM_PASSWORD=lindorm -LINDORM_QUERY_TIMEOUT=1 - -# OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` -OCEANBASE_VECTOR_HOST=oceanbase -OCEANBASE_VECTOR_PORT=2881 -OCEANBASE_VECTOR_USER=root@test -OCEANBASE_VECTOR_PASSWORD=difyai123456 -OCEANBASE_VECTOR_DATABASE=test -OCEANBASE_CLUSTER_NAME=difyai -OCEANBASE_MEMORY_LIMIT=6G -OCEANBASE_ENABLE_HYBRID_SEARCH=false - -# opengauss configurations, only available when VECTOR_STORE is `opengauss` -OPENGAUSS_HOST=opengauss -OPENGAUSS_PORT=6600 
-OPENGAUSS_USER=postgres -OPENGAUSS_PASSWORD=Dify@123 -OPENGAUSS_DATABASE=dify -OPENGAUSS_MIN_CONNECTION=1 -OPENGAUSS_MAX_CONNECTION=5 -OPENGAUSS_ENABLE_PQ=false - -# huawei cloud search service vector configurations, only available when VECTOR_STORE is `huawei_cloud` -HUAWEI_CLOUD_HOSTS=https://127.0.0.1:9200 -HUAWEI_CLOUD_USER=admin -HUAWEI_CLOUD_PASSWORD=admin - -# Upstash Vector configuration, only available when VECTOR_STORE is `upstash` -UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io -UPSTASH_VECTOR_TOKEN=dify - -# TableStore Vector configuration -# (only used when VECTOR_STORE is tablestore) -TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com -TABLESTORE_INSTANCE_NAME=instance-name -TABLESTORE_ACCESS_KEY_ID=xxx -TABLESTORE_ACCESS_KEY_SECRET=xxx - -# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta` -CLICKZETTA_USERNAME= -CLICKZETTA_PASSWORD= -CLICKZETTA_INSTANCE= -CLICKZETTA_SERVICE=api.clickzetta.com -CLICKZETTA_WORKSPACE=quick_start -CLICKZETTA_VCLUSTER=default_ap -CLICKZETTA_SCHEMA=dify -CLICKZETTA_BATCH_SIZE=100 -CLICKZETTA_ENABLE_INVERTED_INDEX=true -CLICKZETTA_ANALYZER_TYPE=chinese -CLICKZETTA_ANALYZER_MODE=smart -CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance - -# ------------------------------ -# Knowledge Configuration -# ------------------------------ - -# Upload file size limit, default 15M. -UPLOAD_FILE_SIZE_LIMIT=15 - -# The maximum number of files that can be uploaded at a time, default 5. -UPLOAD_FILE_BATCH_LIMIT=5 - -# ETL type, support: `dify`, `Unstructured` -# `dify` Dify's proprietary file extraction scheme -# `Unstructured` Unstructured.io file extraction scheme -ETL_TYPE=dify - -# Unstructured API path and API key, needs to be configured when ETL_TYPE is Unstructured -# Or using Unstructured for document extractor node for pptx. -# For example: http://unstructured:8000/general/v0/general -UNSTRUCTURED_API_URL= -UNSTRUCTURED_API_KEY= -SCARF_NO_ANALYTICS=true - -# ------------------------------ -# Model Configuration -# ------------------------------ - -# The maximum number of tokens allowed for prompt generation. -# This setting controls the upper limit of tokens that can be used by the LLM -# when generating a prompt in the prompt generation tool. -# Default: 512 tokens. -PROMPT_GENERATION_MAX_TOKENS=512 - -# The maximum number of tokens allowed for code generation. -# This setting controls the upper limit of tokens that can be used by the LLM -# when generating code in the code generation tool. -# Default: 1024 tokens. -CODE_GENERATION_MAX_TOKENS=1024 - -# Enable or disable plugin based token counting. If disabled, token counting will return 0. -# This can improve performance by skipping token counting operations. -# Default: false (disabled). -PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false - -# ------------------------------ -# Multi-modal Configuration -# ------------------------------ - -# The format of the image/video/audio/document sent when the multi-modal model is input, -# the default is base64, optional url. -# The delay of the call in url mode will be lower than that in base64 mode. -# It is generally recommended to use the more compatible base64 mode. -# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document. -MULTIMODAL_SEND_FORMAT=base64 -# Upload image file size limit, default 10M. -UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 -# Upload video file size limit, default 100M. 
-UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 -# Upload audio file size limit, default 50M. -UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 - -# ------------------------------ -# Sentry Configuration -# Used for application monitoring and error log tracking. -# ------------------------------ -SENTRY_DSN= - -# API Service Sentry DSN address, default is empty, when empty, -# all monitoring information is not reported to Sentry. -# If not set, Sentry error reporting will be disabled. -API_SENTRY_DSN= -# API Service The reporting ratio of Sentry events, if it is 0.01, it is 1%. -API_SENTRY_TRACES_SAMPLE_RATE=1.0 -# API Service The reporting ratio of Sentry profiles, if it is 0.01, it is 1%. -API_SENTRY_PROFILES_SAMPLE_RATE=1.0 - -# Web Service Sentry DSN address, default is empty, when empty, -# all monitoring information is not reported to Sentry. -# If not set, Sentry error reporting will be disabled. -WEB_SENTRY_DSN= - -# ------------------------------ -# Notion Integration Configuration -# Variables can be obtained by applying for Notion integration: https://www.notion.so/my-integrations -# ------------------------------ - -# Configure as "public" or "internal". -# Since Notion's OAuth redirect URL only supports HTTPS, -# if deploying locally, please use Notion's internal integration. -NOTION_INTEGRATION_TYPE=public -# Notion OAuth client secret (used for public integration type) -NOTION_CLIENT_SECRET= -# Notion OAuth client id (used for public integration type) -NOTION_CLIENT_ID= -# Notion internal integration secret. -# If the value of NOTION_INTEGRATION_TYPE is "internal", -# you need to configure this variable. -NOTION_INTERNAL_SECRET= - -# ------------------------------ -# Mail related configuration -# ------------------------------ - -# Mail type, support: resend, smtp, sendgrid -MAIL_TYPE=resend - -# Default send from email address, if not specified -# If using SendGrid, use the 'from' field for authentication if necessary. -MAIL_DEFAULT_SEND_FROM= - -# API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. -RESEND_API_URL=https://api.resend.com -RESEND_API_KEY=your-resend-api-key - - -# SMTP server configuration, used when MAIL_TYPE is `smtp` -SMTP_SERVER= -SMTP_PORT=465 -SMTP_USERNAME= -SMTP_PASSWORD= -SMTP_USE_TLS=true -SMTP_OPPORTUNISTIC_TLS=false - -# Sendgid configuration -SENDGRID_API_KEY= - -# ------------------------------ -# Others Configuration -# ------------------------------ - -# Maximum length of segmentation tokens for indexing -INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 - -# Member invitation link valid time (hours), -# Default: 72. -INVITE_EXPIRY_HOURS=72 - -# Reset password token valid time (minutes), -RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 - -# The sandbox service endpoint. 
-CODE_EXECUTION_ENDPOINT=http://sandbox:8194 -CODE_EXECUTION_API_KEY=dify-sandbox -CODE_MAX_NUMBER=9223372036854775807 -CODE_MIN_NUMBER=-9223372036854775808 -CODE_MAX_DEPTH=5 -CODE_MAX_PRECISION=20 -CODE_MAX_STRING_LENGTH=80000 -CODE_MAX_STRING_ARRAY_LENGTH=30 -CODE_MAX_OBJECT_ARRAY_LENGTH=30 -CODE_MAX_NUMBER_ARRAY_LENGTH=1000 -CODE_EXECUTION_CONNECT_TIMEOUT=10 -CODE_EXECUTION_READ_TIMEOUT=60 -CODE_EXECUTION_WRITE_TIMEOUT=10 -TEMPLATE_TRANSFORM_MAX_LENGTH=80000 - -# Workflow runtime configuration -WORKFLOW_MAX_EXECUTION_STEPS=500 -WORKFLOW_MAX_EXECUTION_TIME=1200 -WORKFLOW_CALL_MAX_DEPTH=5 -MAX_VARIABLE_SIZE=204800 -WORKFLOW_PARALLEL_DEPTH_LIMIT=3 -WORKFLOW_FILE_UPLOAD_LIMIT=10 - -# Workflow storage configuration -# Options: rdbms, hybrid -# rdbms: Use only the relational database (default) -# hybrid: Save new data to object storage, read from both object storage and RDBMS -WORKFLOW_NODE_EXECUTION_STORAGE=rdbms - -# Repository configuration -# Core workflow execution repository implementation -CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository - -# Core workflow node execution repository implementation -CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository - -# API workflow node execution repository implementation -API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository - -# API workflow run repository implementation -API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository - -# HTTP request node in workflow configuration -HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 -HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 -HTTP_REQUEST_NODE_SSL_VERIFY=True - -# Respect X-* headers to redirect clients -RESPECT_XFORWARD_HEADERS_ENABLED=false - -# SSRF Proxy server HTTP URL -SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 -# SSRF Proxy server HTTPS URL -SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 - -# Maximum loop count in the workflow -LOOP_NODE_MAX_COUNT=100 - -# The maximum number of tools that can be used in the agent. -MAX_TOOLS_NUM=10 - -# Maximum number of Parallelism branches in the workflow -MAX_PARALLEL_LIMIT=10 - -# The maximum number of iterations for agent setting -MAX_ITERATIONS_NUM=99 - -# ------------------------------ -# Environment Variables for web Service -# ------------------------------ - -# The timeout for the text generation in millisecond -TEXT_GENERATION_TIMEOUT_MS=60000 - -# Allow rendering unsafe URLs which have "data:" scheme. -ALLOW_UNSAFE_DATA_SCHEME=false - -# ------------------------------ -# Environment Variables for db Service -# ------------------------------ - -# The name of the default postgres user. -POSTGRES_USER=${DB_USERNAME} -# The password for the default postgres user. -POSTGRES_PASSWORD=${DB_PASSWORD} -# The name of the default postgres database. 
-POSTGRES_DB=${DB_DATABASE} -# postgres data directory -PGDATA=/var/lib/postgresql/data/pgdata - -# ------------------------------ -# Environment Variables for sandbox Service -# ------------------------------ - -# The API key for the sandbox service -SANDBOX_API_KEY=dify-sandbox -# The mode in which the Gin framework runs -SANDBOX_GIN_MODE=release -# The timeout for the worker in seconds -SANDBOX_WORKER_TIMEOUT=15 -# Enable network for the sandbox service -SANDBOX_ENABLE_NETWORK=true -# HTTP proxy URL for SSRF protection -SANDBOX_HTTP_PROXY=http://ssrf_proxy:3128 -# HTTPS proxy URL for SSRF protection -SANDBOX_HTTPS_PROXY=http://ssrf_proxy:3128 -# The port on which the sandbox service runs -SANDBOX_PORT=8194 - -# ------------------------------ -# Environment Variables for weaviate Service -# (only used when VECTOR_STORE is weaviate) -# ------------------------------ -WEAVIATE_PERSISTENCE_DATA_PATH=/var/lib/weaviate -WEAVIATE_QUERY_DEFAULTS_LIMIT=25 -WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true -WEAVIATE_DEFAULT_VECTORIZER_MODULE=none -WEAVIATE_CLUSTER_HOSTNAME=node1 -WEAVIATE_AUTHENTICATION_APIKEY_ENABLED=true -WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih -WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai -WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true -WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai - -# ------------------------------ -# Environment Variables for Chroma -# (only used when VECTOR_STORE is chroma) -# ------------------------------ - -# Authentication credentials for Chroma server -CHROMA_SERVER_AUTHN_CREDENTIALS=difyai123456 -# Authentication provider for Chroma server -CHROMA_SERVER_AUTHN_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider -# Persistence setting for Chroma server -CHROMA_IS_PERSISTENT=TRUE - -# ------------------------------ -# Environment Variables for Oracle Service -# (only used when VECTOR_STORE is oracle) -# ------------------------------ -ORACLE_PWD=Dify123456 -ORACLE_CHARACTERSET=AL32UTF8 - -# ------------------------------ -# Environment Variables for milvus Service -# (only used when VECTOR_STORE is milvus) -# ------------------------------ -# ETCD configuration for auto compaction mode -ETCD_AUTO_COMPACTION_MODE=revision -# ETCD configuration for auto compaction retention in terms of number of revisions -ETCD_AUTO_COMPACTION_RETENTION=1000 -# ETCD configuration for backend quota in bytes -ETCD_QUOTA_BACKEND_BYTES=4294967296 -# ETCD configuration for the number of changes before triggering a snapshot -ETCD_SNAPSHOT_COUNT=50000 -# MinIO access key for authentication -MINIO_ACCESS_KEY=minioadmin -# MinIO secret key for authentication -MINIO_SECRET_KEY=minioadmin -# ETCD service endpoints -ETCD_ENDPOINTS=etcd:2379 -# MinIO service address -MINIO_ADDRESS=minio:9000 -# Enable or disable security authorization -MILVUS_AUTHORIZATION_ENABLED=true - -# ------------------------------ -# Environment Variables for pgvector / pgvector-rs Service -# (only used when VECTOR_STORE is pgvector / pgvector-rs) -# ------------------------------ -PGVECTOR_PGUSER=postgres -# The password for the default postgres user. -PGVECTOR_POSTGRES_PASSWORD=difyai123456 -# The name of the default postgres database. 
-PGVECTOR_POSTGRES_DB=dify -# postgres data directory -PGVECTOR_PGDATA=/var/lib/postgresql/data/pgdata - -# ------------------------------ -# Environment Variables for opensearch -# (only used when VECTOR_STORE is opensearch) -# ------------------------------ -OPENSEARCH_DISCOVERY_TYPE=single-node -OPENSEARCH_BOOTSTRAP_MEMORY_LOCK=true -OPENSEARCH_JAVA_OPTS_MIN=512m -OPENSEARCH_JAVA_OPTS_MAX=1024m -OPENSEARCH_INITIAL_ADMIN_PASSWORD=Qazwsxedc!@#123 -OPENSEARCH_MEMLOCK_SOFT=-1 -OPENSEARCH_MEMLOCK_HARD=-1 -OPENSEARCH_NOFILE_SOFT=65536 -OPENSEARCH_NOFILE_HARD=65536 - -# ------------------------------ -# Environment Variables for Nginx reverse proxy -# ------------------------------ -NGINX_SERVER_NAME=_ -NGINX_HTTPS_ENABLED=false -# HTTP port -NGINX_PORT=80 -# SSL settings are only applied when HTTPS_ENABLED is true -NGINX_SSL_PORT=443 -# if HTTPS_ENABLED is true, you're required to add your own SSL certificates/keys to the `./nginx/ssl` directory -# and modify the env vars below accordingly. -NGINX_SSL_CERT_FILENAME=dify.crt -NGINX_SSL_CERT_KEY_FILENAME=dify.key -NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3 - -# Nginx performance tuning -NGINX_WORKER_PROCESSES=auto -NGINX_CLIENT_MAX_BODY_SIZE=100M -NGINX_KEEPALIVE_TIMEOUT=65 - -# Proxy settings -NGINX_PROXY_READ_TIMEOUT=3600s -NGINX_PROXY_SEND_TIMEOUT=3600s - -# Set true to accept requests for /.well-known/acme-challenge/ -NGINX_ENABLE_CERTBOT_CHALLENGE=false - -# ------------------------------ -# Certbot Configuration -# ------------------------------ - -# Email address (required to get certificates from Let's Encrypt) -CERTBOT_EMAIL=your_email@example.com - -# Domain name -CERTBOT_DOMAIN=your_domain.com - -# certbot command options -# i.e: --force-renewal --dry-run --test-cert --debug -CERTBOT_OPTIONS= - -# ------------------------------ -# Environment Variables for SSRF Proxy -# ------------------------------ -SSRF_HTTP_PORT=3128 -SSRF_COREDUMP_DIR=/var/spool/squid -SSRF_REVERSE_PROXY_PORT=8194 -SSRF_SANDBOX_HOST=sandbox -SSRF_DEFAULT_TIME_OUT=5 -SSRF_DEFAULT_CONNECT_TIME_OUT=5 -SSRF_DEFAULT_READ_TIME_OUT=5 -SSRF_DEFAULT_WRITE_TIME_OUT=5 - -# ------------------------------ -# docker env var for specifying vector db type at startup -# (based on the vector db type, the corresponding docker -# compose profile will be used) -# if you want to use unstructured, add ',unstructured' to the end -# ------------------------------ -COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} - -# ------------------------------ -# Docker Compose Service Expose Host Port Configurations -# ------------------------------ -EXPOSE_NGINX_PORT=80 -EXPOSE_NGINX_SSL_PORT=443 - -# ---------------------------------------------------------------------------- -# ModelProvider & Tool Position Configuration -# Used to specify the model providers and tools that can be used in the app. -# ---------------------------------------------------------------------------- - -# Pin, include, and exclude tools -# Use comma-separated values with no spaces between items. -# Example: POSITION_TOOL_PINS=bing,google -POSITION_TOOL_PINS= -POSITION_TOOL_INCLUDES= -POSITION_TOOL_EXCLUDES= - -# Pin, include, and exclude model providers -# Use comma-separated values with no spaces between items. 
-# Example: POSITION_PROVIDER_PINS=openai,openllm -POSITION_PROVIDER_PINS= -POSITION_PROVIDER_INCLUDES= -POSITION_PROVIDER_EXCLUDES= - -# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP -CSP_WHITELIST= - -# Enable or disable create tidb service job -CREATE_TIDB_SERVICE_JOB_ENABLED=false - -# Maximum number of submitted thread count in a ThreadPool for parallel node execution -MAX_SUBMIT_COUNT=100 - -# The maximum number of top-k value for RAG. -TOP_K_MAX_VALUE=10 - -# ------------------------------ -# Plugin Daemon Configuration -# ------------------------------ - -DB_PLUGIN_DATABASE=dify_plugin -EXPOSE_PLUGIN_DAEMON_PORT=5002 -PLUGIN_DAEMON_PORT=5002 -PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi -PLUGIN_DAEMON_URL=http://plugin_daemon:5002 -PLUGIN_MAX_PACKAGE_SIZE=52428800 -PLUGIN_PPROF_ENABLED=false - -PLUGIN_DEBUGGING_HOST=0.0.0.0 -PLUGIN_DEBUGGING_PORT=5003 -EXPOSE_PLUGIN_DEBUGGING_HOST=localhost -EXPOSE_PLUGIN_DEBUGGING_PORT=5003 - -# If this key is changed, DIFY_INNER_API_KEY in plugin_daemon service must also be updated or agent node will fail. -PLUGIN_DIFY_INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 -PLUGIN_DIFY_INNER_API_URL=http://api:5001 - -ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} - -MARKETPLACE_ENABLED=true -MARKETPLACE_API_URL=https://marketplace.dify.ai - -FORCE_VERIFYING_SIGNATURE=true - -PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 -PLUGIN_MAX_EXECUTION_TIMEOUT=600 -# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple -PIP_MIRROR_URL= - -# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example -# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos -PLUGIN_STORAGE_TYPE=local -PLUGIN_STORAGE_LOCAL_ROOT=/app/storage -PLUGIN_WORKING_PATH=/app/storage/cwd -PLUGIN_INSTALLED_PATH=plugin -PLUGIN_PACKAGE_CACHE_PATH=plugin_packages -PLUGIN_MEDIA_CACHE_PATH=assets -# Plugin oss bucket -PLUGIN_STORAGE_OSS_BUCKET= -# Plugin oss s3 credentials -PLUGIN_S3_USE_AWS=false -PLUGIN_S3_USE_AWS_MANAGED_IAM=false -PLUGIN_S3_ENDPOINT= -PLUGIN_S3_USE_PATH_STYLE=false -PLUGIN_AWS_ACCESS_KEY= -PLUGIN_AWS_SECRET_KEY= -PLUGIN_AWS_REGION= -# Plugin oss azure blob -PLUGIN_AZURE_BLOB_STORAGE_CONTAINER_NAME= -PLUGIN_AZURE_BLOB_STORAGE_CONNECTION_STRING= -# Plugin oss tencent cos -PLUGIN_TENCENT_COS_SECRET_KEY= -PLUGIN_TENCENT_COS_SECRET_ID= -PLUGIN_TENCENT_COS_REGION= -# Plugin oss aliyun oss -PLUGIN_ALIYUN_OSS_REGION= -PLUGIN_ALIYUN_OSS_ENDPOINT= -PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID= -PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= -PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 -PLUGIN_ALIYUN_OSS_PATH= -# Plugin oss volcengine tos -PLUGIN_VOLCENGINE_TOS_ENDPOINT= -PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= -PLUGIN_VOLCENGINE_TOS_SECRET_KEY= -PLUGIN_VOLCENGINE_TOS_REGION= - -# ------------------------------ -# OTLP Collector Configuration -# ------------------------------ -ENABLE_OTEL=false -OTLP_TRACE_ENDPOINT= -OTLP_METRIC_ENDPOINT= -OTLP_BASE_ENDPOINT=http://localhost:4318 -OTLP_API_KEY= -OTEL_EXPORTER_OTLP_PROTOCOL= -OTEL_EXPORTER_TYPE=otlp -OTEL_SAMPLING_RATE=0.1 -OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000 -OTEL_MAX_QUEUE_SIZE=2048 -OTEL_MAX_EXPORT_BATCH_SIZE=512 -OTEL_METRIC_EXPORT_INTERVAL=60000 -OTEL_BATCH_EXPORT_TIMEOUT=10000 -OTEL_METRIC_EXPORT_TIMEOUT=30000 - -# Prevent Clickjacking -ALLOW_EMBED=false - -# Dataset queue monitor configuration -QUEUE_MONITOR_THRESHOLD=200 -# You can configure multiple ones, separated by commas. 
eg: test1@dify.ai,test2@dify.ai -QUEUE_MONITOR_ALERT_EMAILS= -# Monitor interval in minutes, default is 30 minutes -QUEUE_MONITOR_INTERVAL=30 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 54f3f42a25..9aad9558b0 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -82,7 +82,7 @@ jobs: - name: Install pnpm uses: pnpm/action-setup@v4 with: - version: 10 + package_json_file: web/package.json run_install: false - name: Setup NodeJS @@ -95,10 +95,12 @@ jobs: - name: Web dependencies if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm install --frozen-lockfile - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm run lint docker-compose-template: diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index 4b06174ee1..c004836808 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -46,7 +46,7 @@ jobs: - name: Install pnpm uses: pnpm/action-setup@v4 with: - version: 10 + package_json_file: web/package.json run_install: false - name: Set up Node.js @@ -59,10 +59,12 @@ jobs: - name: Install dependencies if: env.FILES_CHANGED == 'true' + working-directory: ./web run: pnpm install --frozen-lockfile - name: Generate i18n translations if: env.FILES_CHANGED == 'true' + working-directory: ./web run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} - name: Create Pull Request diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index c3f8fdbaf6..d104d69947 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -35,7 +35,7 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: pnpm/action-setup@v4 with: - version: 10 + package_json_file: web/package.json run_install: false - name: Setup Node.js @@ -48,8 +48,10 @@ jobs: - name: Install dependencies if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm install --frozen-lockfile - name: Run tests if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm test diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..7ce04382c9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,83 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. 
+ +The codebase consists of: +- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture +- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 +- **Docker deployment** (`/docker`): Containerized deployment configurations + +## Development Commands + +### Backend (API) + +All Python commands must be prefixed with `uv run --project api`: + +```bash +# Start development servers +./dev/start-api # Start API server +./dev/start-worker # Start Celery worker + +# Run tests +uv run --project api pytest # Run all tests +uv run --project api pytest tests/unit_tests/ # Unit tests only +uv run --project api pytest tests/integration_tests/ # Integration tests + +# Code quality +./dev/reformat # Run all formatters and linters +uv run --project api ruff check --fix ./ # Fix linting issues +uv run --project api ruff format ./ # Format code +uv run --project api mypy . # Type checking +``` + +### Frontend (Web) + +```bash +cd web +pnpm lint # Run ESLint +pnpm eslint-fix # Fix ESLint issues +pnpm test # Run Jest tests +``` + +## Testing Guidelines + +### Backend Testing +- Use `pytest` for all backend tests +- Write tests first (TDD approach) +- Test structure: Arrange-Act-Assert + +## Code Style Requirements + +### Python +- Use type hints for all functions and class attributes +- No `Any` types unless absolutely necessary +- Implement special methods (`__repr__`, `__str__`) appropriately + +### TypeScript/JavaScript +- Strict TypeScript configuration +- ESLint with Prettier integration +- Avoid `any` type + +## Important Notes + +- **Environment Variables**: Always use UV for Python commands: `uv run --project api ` +- **Comments**: Only write meaningful comments that explain "why", not "what" +- **File Creation**: Always prefer editing existing files over creating new ones +- **Documentation**: Don't create documentation files unless explicitly requested +- **Code Quality**: Always run `./dev/reformat` before committing backend changes + +## Common Development Tasks + +### Adding a New API Endpoint +1. Create controller in `/api/controllers/` +2. Add service logic in `/api/services/` +3. Update routes in controller's `__init__.py` +4. 
Write tests in `/api/tests/` + +## Project-Specific Conventions + +- All async tasks use Celery with Redis as broker diff --git a/README.md b/README.md index 775f6f351f..80e44b0728 100644 --- a/README.md +++ b/README.md @@ -225,7 +225,8 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Using Alibaba Cloud Computing Nest diff --git a/README_AR.md b/README_AR.md index e7a4dbdb27..9c8378d087 100644 --- a/README_AR.md +++ b/README_AR.md @@ -208,7 +208,8 @@ docker compose up -d ##### AWS -- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK بواسطة @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK بواسطة @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### استخدام Alibaba Cloud للنشر [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) diff --git a/README_BN.md b/README_BN.md index e4da437eff..a31aafdf56 100644 --- a/README_BN.md +++ b/README_BN.md @@ -225,7 +225,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud ব্যবহার করে ডিপ্লয় diff --git a/README_CN.md b/README_CN.md index 82149519d3..0698693429 100644 --- a/README_CN.md +++ b/README_CN.md @@ -223,7 +223,8 @@ docker compose up -d 使用 [CDK](https://aws.amazon.com/cdk/) 将 Dify 部署到 AWS ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### 使用 阿里云计算巢 部署 diff --git a/README_DE.md b/README_DE.md index 2420ac0392..392cc7885e 100644 --- a/README_DE.md +++ b/README_DE.md @@ -220,7 +220,8 @@ Stellen Sie Dify mit nur einem Klick mithilfe von [terraform](https://www.terraf Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_ES.md b/README_ES.md index 4fa59dc18f..859da5bfd7 100644 --- a/README_ES.md +++ b/README_ES.md @@ -220,7 +220,8 @@ Despliega Dify en una plataforma en la nube con un solo clic utilizando [terrafo Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK por @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK por @tmokmss (ECS 
based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_FR.md b/README_FR.md index dcbc869620..fcadad419b 100644 --- a/README_FR.md +++ b/README_FR.md @@ -218,7 +218,8 @@ Déployez Dify sur une plateforme cloud en un clic en utilisant [terraform](http Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK par @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK par @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_JA.md b/README_JA.md index d840fd6419..6ddc30789c 100644 --- a/README_JA.md +++ b/README_JA.md @@ -219,7 +219,8 @@ docker compose up -d [CDK](https://aws.amazon.com/cdk/) を使用して、DifyをAWSにデプロイします ##### AWS -- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [@KevinZhaoによるAWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [@tmokmssによるAWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) diff --git a/README_KL.md b/README_KL.md index 41c7969e1c..7232da8003 100644 --- a/README_KL.md +++ b/README_KL.md @@ -218,7 +218,8 @@ wa'logh nIqHom neH ghun deployment toy'wI' [terraform](https://www.terraform.io/ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo'laH. ##### AWS -- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK qachlot @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK qachlot @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_KR.md b/README_KR.md index d4b31a8928..74010d43ed 100644 --- a/README_KR.md +++ b/README_KR.md @@ -212,7 +212,8 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 [CDK](https://aws.amazon.com/cdk/)를 사용하여 AWS에 Dify 배포 ##### AWS -- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [KevinZhao의 AWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [tmokmss의 AWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_PT.md b/README_PT.md index 94452cb233..f9e3ef7f4b 100644 --- a/README_PT.md +++ b/README_PT.md @@ -217,7 +217,8 @@ Implante o Dify na Plataforma Cloud com um único clique usando [terraform](http Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK por @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK por @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_SI.md b/README_SI.md index d840e9155f..ac16df798b 100644 --- a/README_SI.md +++ b/README_SI.md @@ -218,7 +218,8 @@ namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www. 
Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_TR.md b/README_TR.md index 470a7570e0..8065ec908c 100644 --- a/README_TR.md +++ b/README_TR.md @@ -211,7 +211,8 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter [CDK](https://aws.amazon.com/cdk/) kullanarak Dify'ı AWS'ye dağıtın ##### AWS -- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK tarafından @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK tarafından @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/README_TW.md b/README_TW.md index 18f1d2754a..c36027183c 100644 --- a/README_TW.md +++ b/README_TW.md @@ -223,7 +223,8 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ### AWS -- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [由 @KevinZhao 提供的 AWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [由 @tmokmss 提供的 AWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### 使用 阿里云计算巢進行部署 diff --git a/README_VI.md b/README_VI.md index 2ab6da80fc..958a70114a 100644 --- a/README_VI.md +++ b/README_VI.md @@ -213,7 +213,8 @@ Triển khai Dify lên nền tảng đám mây với một cú nhấp chuột b Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK bởi @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK bởi @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud diff --git a/api/.env.example b/api/.env.example index 4beabfecea..3c30872422 100644 --- a/api/.env.example +++ b/api/.env.example @@ -42,6 +42,15 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +REDIS_SSL_CA_CERTS= +# Path to CA certificate file for SSL verification +REDIS_SSL_CERTFILE= +# Path to client certificate file for SSL authentication +REDIS_SSL_KEYFILE= +# Path to client private key file for SSL authentication REDIS_DB=0 # redis Sentinel configuration. 
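The `REDIS_SSL_*` additions to `api/.env.example` above only take effect when `REDIS_USE_SSL=true`. As a hedged sketch of how such settings are typically mapped onto a redis-py client — an illustration under the assumption that redis-py is the underlying driver, not the actual wiring in Dify's `ext_redis` extension or `RedisConfig`:

```python
# Illustrative only: maps the new REDIS_SSL_* variables onto redis-py SSL kwargs.
# The CERT_* translation table and the env-var parsing are assumptions, not Dify code.
import os
import ssl

import redis

_CERT_REQS = {
    "CERT_NONE": ssl.CERT_NONE,
    "CERT_OPTIONAL": ssl.CERT_OPTIONAL,
    "CERT_REQUIRED": ssl.CERT_REQUIRED,
}


def build_redis_client() -> redis.Redis:
    kwargs: dict = {
        "host": os.getenv("REDIS_HOST", "redis"),
        "port": int(os.getenv("REDIS_PORT", "6379")),
        "username": os.getenv("REDIS_USERNAME") or None,
        "password": os.getenv("REDIS_PASSWORD") or None,
        "db": int(os.getenv("REDIS_DB", "0")),
    }
    if os.getenv("REDIS_USE_SSL", "false").lower() == "true":
        # Only consulted when SSL is enabled, mirroring the comments in .env.example.
        kwargs.update(
            ssl=True,
            ssl_cert_reqs=_CERT_REQS.get(os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE"), ssl.CERT_NONE),
            ssl_ca_certs=os.getenv("REDIS_SSL_CA_CERTS") or None,
            ssl_certfile=os.getenv("REDIS_SSL_CERTFILE") or None,
            ssl_keyfile=os.getenv("REDIS_SSL_KEYFILE") or None,
        )
    return redis.Redis(**kwargs)
```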
diff --git a/api/app_factory.py b/api/app_factory.py index 81155cbacd..032d6b17fc 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp): ext_login, ext_mail, ext_migrate, + ext_orjson, ext_otel, ext_proxy_fix, ext_redis, @@ -67,6 +68,7 @@ def initialize_extensions(app: DifyApp): ext_logging, ext_warnings, ext_import_modules, + ext_orjson, ext_set_secretkey, ext_compress, ext_code_based_extension, diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 916f52e165..16dca98cfa 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -39,6 +39,26 @@ class RedisConfig(BaseSettings): default=False, ) + REDIS_SSL_CERT_REQS: str = Field( + description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)", + default="CERT_NONE", + ) + + REDIS_SSL_CA_CERTS: Optional[str] = Field( + description="Path to the CA certificate file for SSL verification", + default=None, + ) + + REDIS_SSL_CERTFILE: Optional[str] = Field( + description="Path to the client certificate file for SSL authentication", + default=None, + ) + + REDIS_SSL_KEYFILE: Optional[str] = Field( + description="Path to the client private key file for SSL authentication", + default=None, + ) + REDIS_USE_SENTINEL: Optional[bool] = Field( description="Enable Redis Sentinel mode for high availability", default=False, diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 732f5b799a..ad94112f05 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,6 +1,7 @@ import logging import flask_login +from flask import request from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value @@ -115,6 +117,10 @@ class ChatMessageApi(Resource): streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + account = flask_login.current_user try: diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a9f088a276..c58301b300 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -23,6 +23,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File +from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from factories import file_factory, variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields @@ -185,6 +186,10 @@ class AdvancedChatDraftWorkflowRunApi(Resource): args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True @@ 
-373,6 +378,10 @@ class DraftWorkflowRunApi(Resource): parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + try: response = AppGenerateService.generate( app_model=app_model, diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index ba93f82756..414c07ef50 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -163,11 +163,11 @@ class WorkflowVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=session, ) - workflow_vars = draft_var_srv.list_variables_without_values( - app_id=app_model.id, - page=args.page, - limit=args.limit, - ) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=app_model.id, + page=args.page, + limit=args.limit, + ) return workflow_vars diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 8237ea3cdc..894785abc8 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -32,7 +32,7 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args.get("current_version")}) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) except Exception as error: logging.warning("Check update version error: %s.", str(error)) result["version"] = args.get("current_version") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 79c860e6b8..073307ac4a 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,3 @@ -import json - from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range from sqlalchemy.orm import Session @@ -136,12 +134,15 @@ class ConversationVariableDetailApi(Resource): variable_id = str(variable_id) parser = reqparse.RequestParser() - parser.add_argument("value", required=True, location="json") + # using lambda is for passing the already-typed value without modification + # if no lambda, it will be converted to string + # the string cannot be converted using json.loads + parser.add_argument("value", required=True, location="json", type=lambda x: x) args = parser.parse_args() try: return ConversationService.update_conversation_variable( - app_model, conversation_id, variable_id, end_user, json.loads(args["value"]) + app_model, conversation_id, variable_id, end_user, args["value"] ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 0c76cc39ae..c273776eb1 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) # init application generate entity application_generate_entity = ChatAppGenerateEntity( diff --git a/api/core/app/apps/completion/app_generator.py 
b/api/core/app/apps/completion/app_generator.py index 9356bd1cea..64dade2968 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_model.id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) # init application generate entity application_generate_entity = CompletionAppGenerateEntity( diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 8507f23f17..4100a0d5a9 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -6,7 +6,6 @@ from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueErrorEvent, - QueueMessage, QueueMessageEndEvent, QueueStopEvent, ) @@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager): self._app_mode = app_mode self._message_id = str(message_id) - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return MessageQueueMessage( - task_id=self._task_id, - message_id=self._message_id, - conversation_id=self._conversation_id, - app_mode=self._app_mode, - event=event, - ) - def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ Publish event to queue diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b416e48ce4..3965f8cb31 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from core.variables.utils import SegmentJSONEncoder +from core.variables.utils import dumps_with_segments class TemplateTransformer(ABC): @@ -93,7 +93,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() + inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index e90c3194f2..df42837796 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -16,15 +16,33 @@ def get_external_trace_id(request: Any) -> Optional[str]: """ Retrieve the trace_id from the request. - Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid. + Priority: + 1. header ('X-Trace-Id') + 2. parameters + 3. JSON body + 4. Current OpenTelemetry context (if enabled) + 5. OpenTelemetry traceparent header (if present and valid) + + Returns None if no valid trace_id is provided. 
""" trace_id = request.headers.get("X-Trace-Id") + if not trace_id: trace_id = request.args.get("trace_id") + if not trace_id and getattr(request, "is_json", False): json_data = getattr(request, "json", None) if json_data: trace_id = json_data.get("trace_id") + + if not trace_id: + trace_id = get_trace_id_from_otel_context() + + if not trace_id: + traceparent = request.headers.get("traceparent") + if traceparent: + trace_id = parse_traceparent_header(traceparent) + if isinstance(trace_id, str) and is_valid_trace_id(trace_id): return trace_id return None @@ -40,3 +58,49 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: if trace_id: return {"external_trace_id": trace_id} return {} + + +def get_trace_id_from_otel_context() -> Optional[str]: + """ + Retrieve the current trace ID from the active OpenTelemetry trace context. + Returns None if: + 1. OpenTelemetry SDK is not installed or enabled. + 2. There is no active span or trace context. + """ + try: + from opentelemetry.trace import SpanContext, get_current_span + from opentelemetry.trace.span import INVALID_TRACE_ID + + span = get_current_span() + if not span: + return None + + span_context: SpanContext = span.get_span_context() + + if not span_context or span_context.trace_id == INVALID_TRACE_ID: + return None + + trace_id_hex = f"{span_context.trace_id:032x}" + return trace_id_hex + + except Exception: + return None + + +def parse_traceparent_header(traceparent: str) -> Optional[str]: + """ + Parse the `traceparent` header to extract the trace_id. + + Expected format: + 'version-trace_id-span_id-flags' + + Reference: + W3C Trace Context Specification: https://www.w3.org/TR/trace-context/ + """ + try: + parts = traceparent.split("-") + if len(parts) == 4 and len(parts[1]) == 32: + return parts[1] + except Exception: + pass + return None diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index 00d5a25956..bad99fc092 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -10,8 +10,6 @@ from core.mcp.types import ( from models.tools import MCPToolProvider from services.tools.mcp_tools_manage_service import MCPToolManageService -LATEST_PROTOCOL_VERSION = "1.0" - class OAuthClientProvider: mcp_provider: MCPToolProvider diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 2d3a3f5344..cc38954eca 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final from urllib.parse import urljoin, urlparse import httpx +from httpx_sse import EventSource, ServerSentEvent from sseclient import SSEClient from core.mcp import types @@ -37,11 +38,6 @@ WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError] -def remove_request_params(url: str) -> str: - """Remove request parameters from URL, keeping only the path.""" - return urljoin(url, urlparse(url).path) - - class SSETransport: """SSE client transport implementation.""" @@ -114,7 +110,7 @@ class SSETransport: logger.exception("Error parsing server message") read_queue.put(exc) - def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: """Handle a single SSE event. 
Args: @@ -130,7 +126,7 @@ class SSETransport: case _: logger.warning("Unknown SSE event: %s", sse.event) - def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: """Read and process SSE events. Args: @@ -225,7 +221,7 @@ class SSETransport: self, executor: ThreadPoolExecutor, client: httpx.Client, - event_source, + event_source: EventSource, ) -> tuple[ReadQueue, WriteQueue]: """Establish connection and start worker threads. diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 496b5432a0..efe91bbff4 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -16,13 +16,14 @@ from extensions.ext_database import db from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService -""" -Apply to MCP HTTP streamable server with stateless http -""" logger = logging.getLogger(__name__) class MCPServerStreamableHTTPRequestHandler: + """ + Apply to MCP HTTP streamable server with stateless http + """ + def __init__( self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] ): diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index a54badcd4c..80912bc4c1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -1,6 +1,10 @@ import json +from collections.abc import Generator +from contextlib import AbstractContextManager import httpx +import httpx_sse +from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError @@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client( ) -def ssrf_proxy_sse_connect(url, **kwargs): +def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]: """Connect to SSE endpoint with SSRF proxy protection. This function creates an SSE connection using the configured proxy settings - to prevent SSRF attacks when connecting to external endpoints. + to prevent SSRF attacks when connecting to external endpoints. It returns + a context manager that yields an EventSource object for SSE streaming. + + The function handles HTTP client creation and cleanup automatically, but + also accepts a pre-configured client via kwargs. Args: - url: The SSE endpoint URL - **kwargs: Additional arguments passed to the SSE connection + url (str): The SSE endpoint URL to connect to + **kwargs: Additional arguments passed to the SSE connection, including: + - client (httpx.Client, optional): Pre-configured HTTP client. + If not provided, one will be created with SSRF protection. + - method (str, optional): HTTP method to use, defaults to "GET" + - headers (dict, optional): HTTP headers to include in the request + - timeout (httpx.Timeout, optional): Timeout configuration for the connection Returns: - EventSource object for SSE streaming + AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource + object for SSE streaming. The EventSource provides access to server-sent events. + + Example: + ```python + with ssrf_proxy_sse_connect(url, headers=headers) as event_source: + for sse in event_source.iter_sse(): + print(sse.event, sse.data) + ``` + + Note: + If a client is not provided in kwargs, one will be automatically created + with SSRF protection based on the application's configuration. 
If an + exception occurs during connection, any automatically created client + will be cleaned up automatically. """ - from httpx_sse import connect_sse # Extract client if provided, otherwise create one client = kwargs.pop("client", None) @@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs): raise -def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None): +def create_mcp_error_response( + request_id: int | str | None, code: int, message: str, data=None +) -> Generator[bytes, None, None]: """Create MCP error response""" error_data = ErrorData(code=code, message=message, data=data) json_response = JSONRPCError( diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index a5c11aeeba..f65339fbfc 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -151,12 +151,9 @@ def jsonable_encoder( return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} - allowed_keys = set(obj.keys()) for key, value in obj.items(): - if ( - (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) - and (value is not None or not exclude_none) - and key in allowed_keys + if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( + value is not None or not exclude_none ): encoded_key = jsonable_encoder( key, diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 06050619e9..82f54582ed 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -4,15 +4,15 @@ from collections.abc import Sequence from typing import Optional from urllib.parse import urljoin -from opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import Link, Status, StatusCode from sqlalchemy.orm import Session, sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( TraceClient, convert_datetime_to_nanoseconds, - convert_string_to_id, convert_to_span_id, convert_to_trace_id, + create_link, generate_span_id, ) from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData @@ -103,10 +103,11 @@ class AliyunDataTrace(BaseTraceInstance): def workflow_trace(self, trace_info: WorkflowTraceInfo): trace_id = convert_to_trace_id(trace_info.workflow_run_id) + links = [] if trace_info.trace_id: - trace_id = convert_string_to_id(trace_info.trace_id) + links.append(create_link(trace_id_str=trace_info.trace_id)) workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") - self.add_workflow_span(trace_id, workflow_span_id, trace_info) + self.add_workflow_span(trace_id, workflow_span_id, trace_info, links) workflow_node_executions = self.get_workflow_node_executions(trace_info) for node_execution in workflow_node_executions: @@ -132,8 +133,9 @@ class AliyunDataTrace(BaseTraceInstance): status = Status(StatusCode.ERROR, trace_info.error) trace_id = convert_to_trace_id(message_id) + links = [] if trace_info.trace_id: - trace_id = convert_string_to_id(trace_info.trace_id) + links.append(create_link(trace_id_str=trace_info.trace_id)) message_span_id = convert_to_span_id(message_id, "message") message_span = SpanData( @@ -152,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: str(trace_info.outputs), }, status=status, + links=links, ) self.trace_client.add_span(message_span) @@ -192,8 +195,9 @@ class AliyunDataTrace(BaseTraceInstance): message_id = trace_info.message_id trace_id = 
convert_to_trace_id(message_id) + links = [] if trace_info.trace_id: - trace_id = convert_string_to_id(trace_info.trace_id) + links.append(create_link(trace_id_str=trace_info.trace_id)) documents_data = extract_retrieval_documents(trace_info.documents) dataset_retrieval_span = SpanData( @@ -211,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance): INPUT_VALUE: str(trace_info.inputs), OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False), }, + links=links, ) self.trace_client.add_span(dataset_retrieval_span) @@ -224,8 +229,9 @@ class AliyunDataTrace(BaseTraceInstance): status = Status(StatusCode.ERROR, trace_info.error) trace_id = convert_to_trace_id(message_id) + links = [] if trace_info.trace_id: - trace_id = convert_string_to_id(trace_info.trace_id) + links.append(create_link(trace_id_str=trace_info.trace_id)) tool_span = SpanData( trace_id=trace_id, @@ -244,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: str(trace_info.tool_outputs), }, status=status, + links=links, ) self.trace_client.add_span(tool_span) @@ -413,7 +420,9 @@ class AliyunDataTrace(BaseTraceInstance): status=self.get_workflow_node_status(node_execution), ) - def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo): + def add_workflow_span( + self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link] + ): message_span_id = None if trace_info.message_id: message_span_id = convert_to_span_id(trace_info.message_id, "message") @@ -438,6 +447,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), }, status=status, + links=links, ) self.trace_client.add_span(message_span) @@ -456,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), }, status=status, + links=links, ) self.trace_client.add_span(workflow_span) @@ -466,8 +477,9 @@ class AliyunDataTrace(BaseTraceInstance): status = Status(StatusCode.ERROR, trace_info.error) trace_id = convert_to_trace_id(message_id) + links = [] if trace_info.trace_id: - trace_id = convert_string_to_id(trace_info.trace_id) + links.append(create_link(trace_id_str=trace_info.trace_id)) suggested_question_span = SpanData( trace_id=trace_id, @@ -487,6 +499,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), }, status=status, + links=links, ) self.trace_client.add_span(suggested_question_span) diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index bd19c8a503..3eb7c30d55 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.util.instrumentation import InstrumentationScope from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData @@ -166,6 +167,16 @@ class SpanBuilder: return span +def create_link(trace_id_str: str) -> Link: + placeholder_span_id = 0x0000000000000000 + trace_id = int(trace_id_str, 16) + span_context = SpanContext( + trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, 
trace_flags=TraceFlags(TraceFlags.SAMPLED) + ) + + return Link(span_context) + + def generate_span_id() -> int: span_id = random.getrandbits(64) while span_id == INVALID_SPAN_ID: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 88addd7e68..789a032654 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -523,7 +523,7 @@ class ProviderManager: # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: - # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic + # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 835d71056d..2555d6b8c0 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,7 +1,7 @@ -import json from collections import defaultdict from typing import Any, Optional +import orjson from pydantic import BaseModel from configs import dify_config @@ -135,13 +135,13 @@ class Jieba(BaseKeyword): dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type if keyword_data_source_type == "database": - dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) + dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) + storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table @@ -157,12 +157,11 @@ class Jieba(BaseKeyword): data_source_type=keyword_data_source_type, ) if keyword_data_source_type == "database": - dataset_keyword_table.keyword_table = json.dumps( + dataset_keyword_table.keyword_table = dumps_with_sets( { "__type__": "keyword_table", "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, - }, - cls=SetEncoder, + } ) db.session.add(dataset_keyword_table) db.session.commit() @@ -257,8 +256,13 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) -class SetEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, set): - return list(obj) - return super().default(obj) +def set_orjson_default(obj: Any) -> Any: + """Default function for orjson serialization of set types""" + if isinstance(obj, set): + return list(obj) + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def dumps_with_sets(obj: Any) -> str: + """JSON dumps with set support using orjson""" + return orjson.dumps(obj, default=set_orjson_default).decode("utf-8") diff --git a/api/core/tools/entities/agent_entities.py b/api/core/tools/entities/agent_entities.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/tools/entities/file_entities.py b/api/core/tools/entities/file_entities.py deleted file mode 100644 index 8b13789179..0000000000 --- a/api/core/tools/entities/file_entities.py +++ /dev/null @@ -1 +0,0 @@ - diff --git 
a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index f9990260bc..b1a7eacf0e 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -108,10 +108,18 @@ class ApiProviderAuthType(Enum): :param value: mode value :return: mode """ + # 'api_key' deprecated in PR #21656 + # normalize & tiny alias for backward compatibility + v = (value or "").strip().lower() + if v == "api_key": + v = cls.API_KEY_HEADER.value + for mode in cls: - if mode.value == value: + if mode.value == v: return mode - raise ValueError(f"invalid mode value {value}") + + valid = ", ".join(m.value for m in cls) + raise ValueError(f"invalid mode value '{value}', expected one of: {valid}") class ToolInvokeMessage(BaseModel): diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index 692db3502e..7ebd29f865 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -1,5 +1,7 @@ -import json from collections.abc import Iterable, Sequence +from typing import Any + +import orjson from .segment_group import SegmentGroup from .segments import ArrayFileSegment, FileSegment, Segment @@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[ return selectors -class SegmentJSONEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [self.default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - else: - super().default(o) +def segment_orjson_default(o: Any) -> Any: + """Default function for orjson serialization of Segment types""" + if isinstance(o, ArrayFileSegment): + return [v.model_dump() for v in o.value] + elif isinstance(o, FileSegment): + return o.value.model_dump() + elif isinstance(o, SegmentGroup): + return [segment_orjson_default(seg) for seg in o.value] + elif isinstance(o, Segment): + return o.value + raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") + + +def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: + """JSON dumps with segment support using orjson""" + option = orjson.OPT_NON_STR_KEYS + return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c0c0cb405c..dfc2a0000b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -5,7 +5,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError @@ -194,17 +194,6 @@ class LLMNode(BaseNode): else [] ) - # single step run fetch file from sys files - if not files and self.invoke_from == InvokeFrom.DEBUGGER and not self.previous_node_id: - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=["sys", "files"], - ) - if self._node_data.vision.enabled - else [] - ) - if files: node_inputs["#files#"] = [file.to_dict() for file in files] diff --git 
a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index df89b2476d..4c8e13de70 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -318,33 +318,6 @@ class ToolNode(BaseNode): json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": message.message.text, - } - - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - stream_text = f"Link: {message.message.text}\n" text += stream_text yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/events/event_handlers/document_index_event.py b/api/events/document_index_event.py similarity index 100% rename from api/events/event_handlers/document_index_event.py rename to api/events/document_index_event.py diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index bdb69945f0..c607161e2a 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import click from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from events.event_handlers.document_index_event import document_index_created +from events.document_index_event import document_index_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index bd72c93404..198f60e554 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,4 +1,6 @@ +import ssl from datetime import timedelta +from typing import Any, Optional import pytz from celery import Celery, Task # type: ignore @@ -8,6 +10,40 @@ from configs import dify_config from dify_app import DifyApp +def _get_celery_ssl_options() -> Optional[dict[str, Any]]: + """Get SSL configuration for Celery broker/backend connections.""" + # Use REDIS_USE_SSL for consistency with the main Redis client + # Only apply SSL if we're using Redis as broker/backend + if not dify_config.REDIS_USE_SSL: + return None + + # Check if Celery is actually using Redis + broker_is_redis = dify_config.CELERY_BROKER_URL and ( + dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://") + ) + + if not broker_is_redis: + return None + + # Map certificate requirement strings to SSL constants + cert_reqs_map = { + "CERT_NONE": ssl.CERT_NONE, + "CERT_OPTIONAL": ssl.CERT_OPTIONAL, + 
"CERT_REQUIRED": ssl.CERT_REQUIRED, + } + + ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) + + ssl_options = { + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, + "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, + "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, + } + + return ssl_options + + def init_app(app: DifyApp) -> Celery: class FlaskTask(Task): def __call__(self, *args: object, **kwargs: object) -> object: @@ -33,14 +69,6 @@ def init_app(app: DifyApp) -> Celery: task_ignore_result=True, ) - # Add SSL options to the Celery configuration - ssl_options = { - "ssl_cert_reqs": None, - "ssl_ca_certs": None, - "ssl_certfile": None, - "ssl_keyfile": None, - } - celery_app.conf.update( result_backend=dify_config.CELERY_RESULT_BACKEND, broker_transport_options=broker_transport_options, @@ -51,9 +79,13 @@ def init_app(app: DifyApp) -> Celery: timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) - if dify_config.BROKER_USE_SSL: + # Apply SSL configuration if enabled + ssl_options = _get_celery_ssl_options() + if ssl_options: celery_app.conf.update( - broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration + broker_use_ssl=ssl_options, + # Also apply SSL to the backend if it's Redis + redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None, ) if dify_config.LOG_FILE: diff --git a/api/extensions/ext_orjson.py b/api/extensions/ext_orjson.py new file mode 100644 index 0000000000..659784a585 --- /dev/null +++ b/api/extensions/ext_orjson.py @@ -0,0 +1,8 @@ +from flask_orjson import OrjsonProvider + +from dify_app import DifyApp + + +def init_app(app: DifyApp) -> None: + """Initialize Flask-Orjson extension for faster JSON serialization""" + app.json = OrjsonProvider(app) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 914d6219cf..f5f544679f 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,5 +1,6 @@ import functools import logging +import ssl from collections.abc import Callable from datetime import timedelta from typing import TYPE_CHECKING, Any, Union @@ -116,76 +117,132 @@ class RedisClientWrapper: redis_client: RedisClientWrapper = RedisClientWrapper() -def init_app(app: DifyApp): - global redis_client - connection_class: type[Union[Connection, SSLConnection]] = Connection - if dify_config.REDIS_USE_SSL: - connection_class = SSLConnection - resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL - if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: - if resp_protocol >= 3: - clientside_cache_config = CacheConfig() - else: - raise ValueError("Client side cache is only supported in RESP3") - else: - clientside_cache_config = None +def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: + """Get SSL configuration for Redis connection.""" + if not dify_config.REDIS_USE_SSL: + return Connection, {} - redis_params: dict[str, Any] = { + cert_reqs_map = { + "CERT_NONE": ssl.CERT_NONE, + "CERT_OPTIONAL": ssl.CERT_OPTIONAL, + "CERT_REQUIRED": ssl.CERT_REQUIRED, + } + ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) + + ssl_kwargs = { + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, + "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, + "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, + } + + return SSLConnection, ssl_kwargs + + +def _get_cache_configuration() -> CacheConfig | None: + """Get client-side cache configuration if 
enabled.""" + if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: + return None + + resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL + if resp_protocol < 3: + raise ValueError("Client side cache is only supported in RESP3") + + return CacheConfig() + + +def _get_base_redis_params() -> dict[str, Any]: + """Get base Redis connection parameters.""" + return { "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, # Temporary fix for empty password + "password": dify_config.REDIS_PASSWORD or None, "db": dify_config.REDIS_DB, "encoding": "utf-8", "encoding_errors": "strict", "decode_responses": False, - "protocol": resp_protocol, - "cache_config": clientside_cache_config, + "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, + "cache_config": _get_cache_configuration(), } - if dify_config.REDIS_USE_SENTINEL: - assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True" - assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, ( - "REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True" - ) - sentinel_hosts = [ - (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") - ] - sentinel = Sentinel( - sentinel_hosts, - sentinel_kwargs={ - "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, - "username": dify_config.REDIS_SENTINEL_USERNAME, - "password": dify_config.REDIS_SENTINEL_PASSWORD, - }, - ) - master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) - redis_client.initialize(master) - elif dify_config.REDIS_USE_CLUSTERS: - assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True" - nodes = [ - ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) - for node in dify_config.REDIS_CLUSTERS.split(",") - ] - redis_client.initialize( - RedisCluster( - startup_nodes=nodes, - password=dify_config.REDIS_CLUSTERS_PASSWORD, - protocol=resp_protocol, - cache_config=clientside_cache_config, - ) - ) - else: - redis_params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - "protocol": resp_protocol, - "cache_config": clientside_cache_config, - } - ) - pool = redis.ConnectionPool(**redis_params) - redis_client.initialize(redis.Redis(connection_pool=pool)) +def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: + """Create Redis client using Sentinel configuration.""" + if not dify_config.REDIS_SENTINELS: + raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") + + if not dify_config.REDIS_SENTINEL_SERVICE_NAME: + raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True") + + sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] + + sentinel = Sentinel( + sentinel_hosts, + sentinel_kwargs={ + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + }, + ) + + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + return master + + +def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: + """Create Redis cluster client.""" + if not dify_config.REDIS_CLUSTERS: + raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True") + + nodes = [ + 
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) + for node in dify_config.REDIS_CLUSTERS.split(",") + ] + + cluster: RedisCluster = RedisCluster( + startup_nodes=nodes, + password=dify_config.REDIS_CLUSTERS_PASSWORD, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), + ) + return cluster + + +def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: + """Create standalone Redis client.""" + connection_class, ssl_kwargs = _get_ssl_configuration() + + redis_params.update( + { + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } + ) + + if ssl_kwargs: + redis_params.update(ssl_kwargs) + + pool = redis.ConnectionPool(**redis_params) + client: redis.Redis = redis.Redis(connection_pool=pool) + return client + + +def init_app(app: DifyApp): + """Initialize Redis client and attach it to the app.""" + global redis_client + + # Determine Redis mode and create appropriate client + if dify_config.REDIS_USE_SENTINEL: + redis_params = _get_base_redis_params() + client = _create_sentinel_client(redis_params) + elif dify_config.REDIS_USE_CLUSTERS: + client = _create_cluster_client() + else: + redis_params = _get_base_redis_params() + client = _create_standalone_client(redis_params) + + # Initialize the wrapper and attach to app + redis_client.initialize(client) app.extensions["redis"] = redis_client diff --git a/api/models/workflow.py b/api/models/workflow.py index ba7396e0a2..ed23cb9c16 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1184,7 +1184,7 @@ class WorkflowDraftVariable(Base): value: The Segment object to store as the variable's value. """ self.__value = value - self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) + self.value = variable_utils.dumps_with_segments(value) self.value_type = value.value_type def get_node_id(self) -> str | None: diff --git a/api/mypy.ini b/api/mypy.ini index 6836b2602b..3a6a54afe1 100644 --- a/api/mypy.ini +++ b/api/mypy.ini @@ -5,8 +5,7 @@ check_untyped_defs = True cache_fine_grained = True sqlite_cache = True exclude = (?x)( - core/model_runtime/model_providers/ - | tests/ + tests/ | migrations/ ) diff --git a/api/pyproject.toml b/api/pyproject.toml index de472c870a..61a725a830 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "flask-cors~=6.0.0", "flask-login~=0.6.3", "flask-migrate~=4.0.7", + "flask-orjson~=2.0.0", "flask-restful~=0.3.10", "flask-sqlalchemy~=3.1.1", "gevent~=24.11.1", diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 713c4c6782..4f3dd3c762 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -103,10 +103,10 @@ class ConversationService: @classmethod def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if sort_direction == desc: + if sort_direction is desc: return getattr(Conversation, sort_field) < field_value - else: - return getattr(Conversation, sort_field) > field_value + + return getattr(Conversation, sort_field) > field_value @classmethod def rename( @@ -147,7 +147,7 @@ class ConversationService: app_model.tenant_id, message.query, conversation.id, app_model.id ) conversation.name = name - except: + except Exception: pass db.session.commit() @@ -277,6 +277,11 @@ class ConversationService: # 
Validate that the new value type matches the expected variable type expected_type = SegmentType(current_variable.value_type) + + # There is showing number in web ui but int in db + if expected_type == SegmentType.INTEGER: + expected_type = SegmentType.NUMBER + if not expected_type.is_valid(new_value): inferred_type = SegmentType.infer_segment_type(new_value) raise ConversationVariableTypeMismatchError( diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 4a09e71504..057b20428f 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -55,7 +55,9 @@ class OAuthProxyService(BasePluginClient): if not context_id: raise ValueError("context_id is required") # get data from redis - data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}") + key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}" + data = redis_client.get(key) if not data: raise ValueError("context_id is invalid") + redis_client.delete(key) return json.loads(data) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py new file mode 100644 index 0000000000..7fef572c14 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -0,0 +1,1144 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.rag.index_processor.constant.built_in_field import BuiltInField +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document +from services.entities.knowledge_entities.knowledge_entities import MetadataArgs +from services.metadata_service import MetadataService + + +class TestMetadataService: + """Integration tests for MetadataService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.metadata_service.current_user") as mock_current_user, + patch("services.metadata_service.redis_client") as mock_redis_client, + patch("services.dataset_service.DocumentService") as mock_document_service, + ): + # Setup default mock returns + mock_redis_client.get.return_value = None + mock_redis_client.set.return_value = True + mock_redis_client.delete.return_value = 1 + + yield { + "current_user": mock_current_user, + "redis_client": mock_redis_client, + "document_service": mock_document_service, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + """ + Helper method to create a test dataset for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + account: Account instance + tenant: Tenant instance + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + created_by=account.id, + built_in_field_enabled=False, + ) + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + """ + Helper method to create a test document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + dataset: Dataset instance + account: Account instance + + Returns: + Document: Created document instance + """ + fake = Faker() + + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + data_source_info="{}", + batch="test-batch", + name=fake.file_name(), + created_from="web", + created_by=account.id, + doc_form="text", + doc_language="en", + ) + + from extensions.ext_database import db + + db.session.add(document) + db.session.commit() + + return document + + def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata creation with valid parameters. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + metadata_args = MetadataArgs(type="string", name="test_metadata") + + # Act: Execute the method under test + result = MetadataService.create_metadata(dataset.id, metadata_args) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "test_metadata" + assert result.type == "string" + assert result.dataset_id == dataset.id + assert result.tenant_id == tenant.id + assert result.created_by == account.id + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.created_at is not None + + def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata creation fails when name exceeds 255 characters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + long_name = "a" * 256 # 256 characters, exceeding 255 limit + metadata_args = MetadataArgs(type="string", name=long_name) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): + MetadataService.create_metadata(dataset.id, metadata_args) + + def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata creation fails when name already exists in the same dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create first metadata + first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + MetadataService.create_metadata(dataset.id, first_metadata_args) + + # Try to create second metadata with same name + second_metadata_args = MetadataArgs(type="number", name="duplicate_name") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name already exists."): + MetadataService.create_metadata(dataset.id, second_metadata_args) + + def test_create_metadata_name_conflicts_with_built_in_field( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata creation fails when name conflicts with built-in field names. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to create metadata with built-in field name + built_in_field_name = BuiltInField.document_name.value + metadata_args = MetadataArgs(type="string", name=built_in_field_name) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): + MetadataService.create_metadata(dataset.id, metadata_args) + + def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata name update with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + new_name = "new_name" + result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == new_name + assert result.updated_by == account.id + assert result.updated_at is not None + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.name == new_name + + def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when new name exceeds 255 characters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Try to update with too long name + long_name = "a" * 256 # 256 characters, exceeding 255 limit + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): + MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) + + def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when new name already exists in the same dataset. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create two metadata entries + first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) + + second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) + + # Try to update first metadata with second metadata's name + with pytest.raises(ValueError, match="Metadata name already exists."): + MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") + + def test_update_metadata_name_conflicts_with_built_in_field( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata name update fails when new name conflicts with built-in field names. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Try to update with built-in field name + built_in_field_name = BuiltInField.document_name.value + + with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): + MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) + + def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when metadata ID does not exist. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to update non-existent metadata + import uuid + + fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format + new_name = "new_name" + + # Act: Execute the method under test + result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name) + + # Assert: Verify the method returns None when metadata is not found + assert result is None + + def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata deletion with valid parameters. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, metadata.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.id == metadata.id + + # Verify metadata was deleted from database + from extensions.ext_database import db + + deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + assert deleted_metadata is None + + def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata deletion fails when metadata ID does not exist. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to delete non-existent metadata + import uuid + + fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, fake_metadata_id) + + # Assert: Verify the method returns None when metadata is not found + assert result is None + + def test_delete_metadata_with_document_bindings( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata deletion successfully removes document metadata bindings. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Create metadata binding + binding = DatasetMetadataBinding( + tenant_id=tenant.id, + dataset_id=dataset.id, + metadata_id=metadata.id, + document_id=document.id, + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(binding) + db.session.commit() + + # Set document metadata + document.doc_metadata = {"test_metadata": "test_value"} + db.session.add(document) + db.session.commit() + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, metadata.id) + + # Assert: Verify the expected outcomes + assert result is not None + + # Verify metadata was deleted from database + deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + assert deleted_metadata is None + + # Note: The service attempts to update document metadata but may not succeed + # due to mock configuration. The main functionality (metadata deletion) is verified. + + def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of built-in metadata fields. + """ + # Act: Execute the method under test + result = MetadataService.get_built_in_fields() + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 5 + + # Verify all expected built-in fields are present + field_names = [field["name"] for field in result] + field_types = [field["type"] for field in result] + + assert BuiltInField.document_name.value in field_names + assert BuiltInField.uploader.value in field_names + assert BuiltInField.upload_date.value in field_names + assert BuiltInField.last_update_date.value in field_names + assert BuiltInField.source.value in field_names + + # Verify field types + assert "string" in field_types + assert "time" in field_types + + def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of built-in fields for a dataset. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ + document + ] + + # Verify dataset starts with built-in fields disabled + assert dataset.built_in_field_enabled is False + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + # Note: Document metadata update depends on DocumentService mock working correctly + # The main functionality (enabling built-in fields) is verified + + def test_enable_built_in_field_already_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test enabling built-in fields when they are already enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the method returns early without changes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + def test_enable_built_in_field_with_no_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test enabling built-in fields for a dataset with no documents. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Mock DocumentService.get_working_documents_by_dataset_id to return empty list + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of built-in fields for a dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Set document metadata with built-in fields + document.doc_metadata = { + BuiltInField.document_name.value: document.name, + BuiltInField.uploader.value: "test_uploader", + BuiltInField.upload_date.value: 1234567890.0, + BuiltInField.last_update_date.value: 1234567890.0, + BuiltInField.source.value: "test_source", + } + db.session.add(document) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ + document + ] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + # Note: Document metadata update depends on DocumentService mock working correctly + # The main functionality (disabling built-in fields) is verified + + def test_disable_built_in_field_already_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test disabling built-in fields when they are already disabled. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Verify dataset starts with built-in fields disabled + assert dataset.built_in_field_enabled is False + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the method returns early without changes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + def test_disable_built_in_field_with_no_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test disabling built-in fields for a dataset with no documents. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id to return empty list + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful update of documents metadata. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document + mock_external_service_dependencies["document_service"].get_document.return_value = document + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id=document.id, metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify document metadata was updated + db.session.refresh(document) + assert document.doc_metadata is not None + assert "test_metadata" in document.doc_metadata + assert document.doc_metadata["test_metadata"] == "test_value" + + # Verify metadata binding was created + binding = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + ) + assert binding is not None + assert binding.tenant_id == tenant.id + assert binding.dataset_id == dataset.id + + def test_update_documents_metadata_with_built_in_fields_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test update of documents metadata when built-in fields are enabled. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Enable built-in fields + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document + mock_external_service_dependencies["document_service"].get_document.return_value = document + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id=document.id, metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the expected outcomes + # Verify document metadata was updated with both custom and built-in fields + db.session.refresh(document) + assert document.doc_metadata is not None + assert "test_metadata" in document.doc_metadata + assert document.doc_metadata["test_metadata"] == "test_value" + + # Note: Built-in fields would be added if DocumentService mock works correctly + # The main functionality (custom metadata update) is verified + + def test_update_documents_metadata_document_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test update of documents metadata when document is not found. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document to return None (document not found) + mock_external_service_dependencies["document_service"].get_document.return_value = None + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id="non-existent-document-id", metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + # The method should handle the error gracefully and continue + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the method completes without raising exceptions + # The main functionality (error handling) is verified + + def test_knowledge_base_metadata_lock_check_dataset_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check for dataset operations. + """ + # Arrange: Setup mocks + mock_external_service_dependencies["redis_client"].get.return_value = None + mock_external_service_dependencies["redis_client"].set.return_value = True + + dataset_id = "test-dataset-id" + + # Act: Execute the method under test + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + + # Assert: Verify the expected outcomes + # Verify Redis lock was set + mock_external_service_dependencies["redis_client"].set.assert_called_once() + + # Verify lock key format + call_args = mock_external_service_dependencies["redis_client"].set.call_args + assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" + + def test_knowledge_base_metadata_lock_check_document_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check for document operations. + """ + # Arrange: Setup mocks + mock_external_service_dependencies["redis_client"].get.return_value = None + mock_external_service_dependencies["redis_client"].set.return_value = True + + document_id = "test-document-id" + + # Act: Execute the method under test + MetadataService.knowledge_base_metadata_lock_check(None, document_id) + + # Assert: Verify the expected outcomes + # Verify Redis lock was set + mock_external_service_dependencies["redis_client"].set.assert_called_once() + + # Verify lock key format + call_args = mock_external_service_dependencies["redis_client"].set.call_args + assert call_args[0][0] == f"document_metadata_lock_{document_id}" + + def test_knowledge_base_metadata_lock_check_lock_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check when lock already exists. 
+ """ + # Arrange: Setup mocks to simulate existing lock + mock_external_service_dependencies["redis_client"].get.return_value = "1" # Lock exists + + dataset_id = "test-dataset-id" + + # Act & Assert: Verify proper error handling + with pytest.raises( + ValueError, match="Another knowledge base metadata operation is running, please wait a moment." + ): + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + + def test_knowledge_base_metadata_lock_check_document_lock_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check when document lock already exists. + """ + # Arrange: Setup mocks to simulate existing lock + mock_external_service_dependencies["redis_client"].get.return_value = "1" # Lock exists + + document_id = "test-document-id" + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): + MetadataService.knowledge_base_metadata_lock_check(None, document_id) + + def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of dataset metadata information. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Create document and metadata binding + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + binding = DatasetMetadataBinding( + tenant_id=tenant.id, + dataset_id=dataset.id, + metadata_id=metadata.id, + document_id=document.id, + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(binding) + db.session.commit() + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 1 + assert doc_metadata[0]["id"] == metadata.id + assert doc_metadata[0]["name"] == metadata.name + assert doc_metadata[0]["type"] == metadata.type + assert doc_metadata[0]["count"] == 1 # One document bound to this metadata + + # Verify built-in field status + assert result["built_in_field_enabled"] is False + + def test_get_dataset_metadatas_with_built_in_fields_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of dataset metadata when built-in fields are enabled. 
+ """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Enable built-in fields + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 1 # Only custom metadata, built-in fields are not included in this list + + # Verify built-in field status + assert result["built_in_field_enabled"] is True + + def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of dataset metadata when no metadata exists. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 0 # No metadata exists + + # Verify built-in field status + assert result["built_in_field_enabled"] is False diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py new file mode 100644 index 0000000000..a8a36b2565 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,474 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models.account import TenantAccountJoin, TenantAccountRole +from models.model import Account, Tenant +from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting +from services.model_load_balancing_service import ModelLoadBalancingService + + +class TestModelLoadBalancingService: + """Integration tests for ModelLoadBalancingService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager, + patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory, + patch("services.model_load_balancing_service.encrypter") as 
mock_encrypter, + ): + # Setup default mock returns + mock_provider_manager_instance = mock_provider_manager.return_value + + # Mock provider configuration + mock_provider_config = MagicMock() + mock_provider_config.provider.provider = "openai" + mock_provider_config.custom_configuration.provider = None + + # Mock provider model setting + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = False + + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + # Mock provider configurations dict + mock_provider_configs = {"openai": mock_provider_config} + mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs + + # Mock LBModelManager + mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) + + # Mock ModelProviderFactory + mock_model_provider_factory_instance = mock_model_provider_factory.return_value + + # Mock credential schemas + mock_credential_schema = MagicMock() + mock_credential_schema.credential_form_schemas = [] + + # Mock provider configuration methods + mock_provider_config.extract_secret_variables.return_value = [] + mock_provider_config.obfuscated_credentials.return_value = {} + mock_provider_config._get_credential_schema.return_value = mock_credential_schema + + yield { + "provider_manager": mock_provider_manager, + "lb_model_manager": mock_lb_model_manager, + "model_provider_factory": mock_model_provider_factory, + "encrypter": mock_encrypter, + "provider_config": mock_provider_config, + "provider_model_setting": mock_provider_model_setting, + "credential_schema": mock_credential_schema, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_provider_and_setting( + self, db_session_with_containers, tenant_id, mock_external_service_dependencies + ): + """ + Helper method to create a test provider and provider model setting. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant_id: Tenant ID for the provider + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (provider, provider_model_setting) - Created provider and setting instances + """ + fake = Faker() + + from extensions.ext_database import db + + # Create provider + provider = Provider( + tenant_id=tenant_id, + provider_name="openai", + provider_type="custom", + is_valid=True, + ) + db.session.add(provider) + db.session.commit() + + # Create provider model setting + provider_model_setting = ProviderModelSetting( + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + enabled=True, + load_balancing_enabled=False, + ) + db.session.add(provider_model_setting) + db.session.commit() + + return provider, provider_model_setting + + def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful model load balancing enablement. + + This test verifies: + - Proper provider configuration retrieval + - Successful enablement of model load balancing + - Correct method calls to provider configuration + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Setup mocks for enable method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.enable_model_load_balancing = MagicMock() + + # Act: Execute the method under test + service = ModelLoadBalancingService() + service.enable_model_load_balancing( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + mock_provider_config.enable_model_load_balancing.assert_called_once() + call_args = mock_provider_config.enable_model_load_balancing.call_args + assert call_args.kwargs["model"] == "gpt-3.5-turbo" + assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(provider) + db.session.refresh(provider_model_setting) + assert provider.id is not None + assert provider_model_setting.id is not None + + def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful model load balancing disablement. 
+ + This test verifies: + - Proper provider configuration retrieval + - Successful disablement of model load balancing + - Correct method calls to provider configuration + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Setup mocks for disable method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.disable_model_load_balancing = MagicMock() + + # Act: Execute the method under test + service = ModelLoadBalancingService() + service.disable_model_load_balancing( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + mock_provider_config.disable_model_load_balancing.assert_called_once() + call_args = mock_provider_config.disable_model_load_balancing.call_args + assert call_args.kwargs["model"] == "gpt-3.5-turbo" + assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(provider) + db.session.refresh(provider_model_setting) + assert provider.id is not None + assert provider_model_setting.id is not None + + def test_enable_model_load_balancing_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when provider does not exist. + + This test verifies: + - Proper error handling for non-existent provider + - Correct exception type and message + - No database state changes + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup mocks to return empty provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"] + mock_provider_manager_instance = mock_provider_manager.return_value + mock_provider_manager_instance.get_configurations.return_value = {} + + # Act & Assert: Verify proper error handling + service = ModelLoadBalancingService() + with pytest.raises(ValueError) as exc_info: + service.enable_model_load_balancing( + tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" + ) + + # Verify correct error message + assert "Provider nonexistent_provider does not exist." in str(exc_info.value) + + # Verify no database state changes occurred + from extensions.ext_database import db + + db.session.rollback() + + def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of load balancing configurations. 
+ + This test verifies: + - Proper provider configuration retrieval + - Successful database query for load balancing configs + - Correct return format and data structure + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Create load balancing config + from extensions.ext_database import db + + load_balancing_config = LoadBalancingModelConfig( + tenant_id=tenant.id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + name="config1", + encrypted_config='{"api_key": "test_key"}', + enabled=True, + ) + db.session.add(load_balancing_config) + db.session.commit() + + # Verify the config was created + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + # Setup mocks for get_load_balancing_configs method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] + mock_provider_model_setting.load_balancing_enabled = True + + # Mock credential schema methods + mock_credential_schema = mock_external_service_dependencies["credential_schema"] + mock_credential_schema.credential_form_schemas = [] + + # Mock encrypter + mock_encrypter = mock_external_service_dependencies["encrypter"] + mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") + + # Mock _get_credential_schema method + mock_provider_config._get_credential_schema.return_value = mock_credential_schema + + # Mock extract_secret_variables method + mock_provider_config.extract_secret_variables.return_value = [] + + # Mock obfuscated_credentials method + mock_provider_config.obfuscated_credentials.return_value = {} + + # Mock LBModelManager.get_config_in_cooldown_and_ttl + mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"] + mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) + + # Act: Execute the method under test + service = ModelLoadBalancingService() + is_enabled, configs = service.get_load_balancing_configs( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + assert is_enabled is True + assert len(configs) == 1 + assert configs[0]["id"] == load_balancing_config.id + assert configs[0]["name"] == "config1" + assert configs[0]["enabled"] is True + assert configs[0]["in_cooldown"] is False + assert configs[0]["ttl"] == 0 + + # Verify database state + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + def test_get_load_balancing_configs_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when provider does not exist in get_load_balancing_configs. 
+ + This test verifies: + - Proper error handling for non-existent provider + - Correct exception type and message + - No database state changes + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup mocks to return empty provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"] + mock_provider_manager_instance = mock_provider_manager.return_value + mock_provider_manager_instance.get_configurations.return_value = {} + + # Act & Assert: Verify proper error handling + service = ModelLoadBalancingService() + with pytest.raises(ValueError) as exc_info: + service.get_load_balancing_configs( + tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" + ) + + # Verify correct error message + assert "Provider nonexistent_provider does not exist." in str(exc_info.value) + + # Verify no database state changes occurred + from extensions.ext_database import db + + db.session.rollback() + + def test_get_load_balancing_configs_with_inherit_config( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test load balancing configs retrieval with inherit configuration. + + This test verifies: + - Proper handling of inherit configuration + - Correct ordering of configurations + - Inherit config initialization when needed + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Create load balancing config + from extensions.ext_database import db + + load_balancing_config = LoadBalancingModelConfig( + tenant_id=tenant.id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + name="config1", + encrypted_config='{"api_key": "test_key"}', + enabled=True, + ) + db.session.add(load_balancing_config) + db.session.commit() + + # Setup mocks for inherit config scenario + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config + + mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] + mock_provider_model_setting.load_balancing_enabled = True + + # Mock credential schema methods + mock_credential_schema = mock_external_service_dependencies["credential_schema"] + mock_credential_schema.credential_form_schemas = [] + + # Mock encrypter + mock_encrypter = mock_external_service_dependencies["encrypter"] + mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") + + # Act: Execute the method under test + service = ModelLoadBalancingService() + is_enabled, configs = service.get_load_balancing_configs( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + assert is_enabled is True + assert len(configs) == 2 # inherit config + existing config + + # First config should be inherit config + assert configs[0]["name"] == "__inherit__" + assert configs[0]["enabled"] is True + + # Second config should be the existing config + assert configs[1]["id"] == load_balancing_config.id + assert 
configs[1]["name"] == "config1" + + # Verify database state + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + # Verify inherit config was created in database + inherit_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all() + ) + assert len(inherit_configs) == 1 diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index 8122cd08eb..880a0d4940 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -262,26 +262,6 @@ def test_sse_client_queue_cleanup(): # Note: In real implementation, cleanup should put None to signal shutdown -def test_sse_client_url_processing(): - """Test SSE client URL processing functions.""" - from core.mcp.client.sse_client import remove_request_params - - # Test URL with parameters - url_with_params = "http://example.com/sse?param1=value1¶m2=value2" - cleaned_url = remove_request_params(url_with_params) - assert cleaned_url == "http://example.com/sse" - - # Test URL without parameters - url_without_params = "http://example.com/sse" - cleaned_url = remove_request_params(url_without_params) - assert cleaned_url == "http://example.com/sse" - - # Test URL with path and parameters - complex_url = "http://example.com/path/to/sse?session=123&token=abc" - cleaned_url = remove_request_params(complex_url) - assert cleaned_url == "http://example.com/path/to/sse" - - def test_sse_client_headers_propagation(): """Test that custom headers are properly propagated in SSE client.""" test_url = "http://test.example/sse" diff --git a/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py b/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py new file mode 100644 index 0000000000..4029edfb68 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py @@ -0,0 +1,481 @@ +import json +from datetime import date, datetime +from decimal import Decimal +from uuid import uuid4 + +import numpy as np +import pytest +import pytz + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_dict, safe_json_value + + +class TestSafeJsonValue: + """Test suite for safe_json_value function to ensure proper serialization of complex types""" + + def test_datetime_conversion(self): + """Test datetime conversion with timezone handling""" + # Test datetime with UTC timezone + dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC) + result = safe_json_value(dt) + assert isinstance(result, str) + assert "2024-01-01T12:00:00+00:00" in result + + # Test datetime without timezone (should default to UTC) + dt_no_tz = datetime(2024, 1, 1, 12, 0, 0) + result = safe_json_value(dt_no_tz) + assert isinstance(result, str) + # The exact time will depend on the system's timezone, so we check the format + assert "T" in result # ISO format separator + # Check that it's a valid ISO format datetime string + assert len(result) >= 19 # At least YYYY-MM-DDTHH:MM:SS + + def test_date_conversion(self): + """Test date conversion to ISO format""" + test_date = date(2024, 1, 1) + result = safe_json_value(test_date) + assert result == "2024-01-01" + + def test_uuid_conversion(self): + """Test UUID conversion to string""" + test_uuid = uuid4() + result = safe_json_value(test_uuid) + assert isinstance(result, str) + assert result == str(test_uuid) + + 
def test_decimal_conversion(self): + """Test Decimal conversion to float""" + test_decimal = Decimal("123.456") + result = safe_json_value(test_decimal) + assert result == 123.456 + assert isinstance(result, float) + + def test_bytes_conversion(self): + """Test bytes conversion with UTF-8 decoding""" + # Test valid UTF-8 bytes + test_bytes = b"Hello, World!" + result = safe_json_value(test_bytes) + assert result == "Hello, World!" + + # Test invalid UTF-8 bytes (should fall back to hex) + invalid_bytes = b"\xff\xfe\xfd" + result = safe_json_value(invalid_bytes) + assert result == "fffefd" + + def test_memoryview_conversion(self): + """Test memoryview conversion to hex string""" + test_bytes = b"test data" + test_memoryview = memoryview(test_bytes) + result = safe_json_value(test_memoryview) + assert result == "746573742064617461" # hex of "test data" + + def test_numpy_ndarray_conversion(self): + """Test numpy ndarray conversion to list""" + # Test 1D array + test_array = np.array([1, 2, 3, 4]) + result = safe_json_value(test_array) + assert result == [1, 2, 3, 4] + + # Test 2D array + test_2d_array = np.array([[1, 2], [3, 4]]) + result = safe_json_value(test_2d_array) + assert result == [[1, 2], [3, 4]] + + # Test array with float values + test_float_array = np.array([1.5, 2.7, 3.14]) + result = safe_json_value(test_float_array) + assert result == [1.5, 2.7, 3.14] + + def test_dict_conversion(self): + """Test dictionary conversion using safe_json_dict""" + test_dict = { + "string": "value", + "number": 42, + "float": 3.14, + "boolean": True, + "list": [1, 2, 3], + "nested": {"key": "value"}, + } + result = safe_json_value(test_dict) + assert isinstance(result, dict) + assert result == test_dict + + def test_list_conversion(self): + """Test list conversion with mixed types""" + test_list = [ + "string", + 42, + 3.14, + True, + [1, 2, 3], + {"key": "value"}, + datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + Decimal("123.456"), + uuid4(), + ] + result = safe_json_value(test_list) + assert isinstance(result, list) + assert len(result) == len(test_list) + assert isinstance(result[6], str) # datetime should be converted to string + assert isinstance(result[7], float) # Decimal should be converted to float + assert isinstance(result[8], str) # UUID should be converted to string + + def test_tuple_conversion(self): + """Test tuple conversion to list""" + test_tuple = (1, "string", 3.14) + result = safe_json_value(test_tuple) + assert isinstance(result, list) + assert result == [1, "string", 3.14] + + def test_set_conversion(self): + """Test set conversion to list""" + test_set = {1, "string", 3.14} + result = safe_json_value(test_set) + assert isinstance(result, list) + # Note: set order is not guaranteed, so we check length and content + assert len(result) == 3 + assert 1 in result + assert "string" in result + assert 3.14 in result + + def test_basic_types_passthrough(self): + """Test that basic types are passed through unchanged""" + assert safe_json_value("string") == "string" + assert safe_json_value(42) == 42 + assert safe_json_value(3.14) == 3.14 + assert safe_json_value(True) is True + assert safe_json_value(False) is False + assert safe_json_value(None) is None + + def test_nested_complex_structure(self): + """Test complex nested structure with all types""" + complex_data = { + "dates": [date(2024, 1, 1), date(2024, 1, 2)], + "timestamps": [ + datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + datetime(2024, 1, 2, 12, 0, 0, tzinfo=pytz.UTC), + ], + "numbers": [Decimal("123.456"), 
Decimal("789.012")], + "identifiers": [uuid4(), uuid4()], + "binary_data": [b"hello", b"world"], + "arrays": [np.array([1, 2, 3]), np.array([4, 5, 6])], + } + + result = safe_json_value(complex_data) + + # Verify structure is maintained + assert isinstance(result, dict) + assert "dates" in result + assert "timestamps" in result + assert "numbers" in result + assert "identifiers" in result + assert "binary_data" in result + assert "arrays" in result + + # Verify conversions + assert all(isinstance(d, str) for d in result["dates"]) + assert all(isinstance(t, str) for t in result["timestamps"]) + assert all(isinstance(n, float) for n in result["numbers"]) + assert all(isinstance(i, str) for i in result["identifiers"]) + assert all(isinstance(b, str) for b in result["binary_data"]) + assert all(isinstance(a, list) for a in result["arrays"]) + + +class TestSafeJsonDict: + """Test suite for safe_json_dict function""" + + def test_valid_dict_conversion(self): + """Test valid dictionary conversion""" + test_dict = { + "string": "value", + "number": 42, + "datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "decimal": Decimal("123.456"), + } + result = safe_json_dict(test_dict) + assert isinstance(result, dict) + assert result["string"] == "value" + assert result["number"] == 42 + assert isinstance(result["datetime"], str) + assert isinstance(result["decimal"], float) + + def test_invalid_input_type(self): + """Test that invalid input types raise TypeError""" + with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"): + safe_json_dict("not a dict") + + with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"): + safe_json_dict([1, 2, 3]) + + with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"): + safe_json_dict(42) + + def test_empty_dict(self): + """Test empty dictionary handling""" + result = safe_json_dict({}) + assert result == {} + + def test_nested_dict_conversion(self): + """Test nested dictionary conversion""" + test_dict = { + "level1": { + "level2": {"datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), "decimal": Decimal("123.456")} + } + } + result = safe_json_dict(test_dict) + assert isinstance(result["level1"]["level2"]["datetime"], str) + assert isinstance(result["level1"]["level2"]["decimal"], float) + + +class TestToolInvokeMessageJsonSerialization: + """Test suite for ToolInvokeMessage JSON serialization through safe_json_value""" + + def test_json_message_serialization(self): + """Test JSON message serialization with complex data""" + complex_data = { + "timestamp": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "amount": Decimal("123.45"), + "id": uuid4(), + "binary": b"test data", + "array": np.array([1, 2, 3]), + } + + # Create JSON message + json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + # Apply safe_json_value transformation + transformed_data = safe_json_value(message.message.json_object) + + # Verify transformations + assert isinstance(transformed_data["timestamp"], str) + assert isinstance(transformed_data["amount"], float) + assert isinstance(transformed_data["id"], str) + assert isinstance(transformed_data["binary"], str) + assert isinstance(transformed_data["array"], list) + + # Verify JSON serialization works + json_string = json.dumps(transformed_data, ensure_ascii=False) + assert 
isinstance(json_string, str) + + # Verify we can deserialize back + deserialized = json.loads(json_string) + assert deserialized["amount"] == 123.45 + assert deserialized["array"] == [1, 2, 3] + + def test_json_message_with_nested_structures(self): + """Test JSON message with deeply nested complex structures""" + nested_data = { + "level1": { + "level2": { + "level3": { + "dates": [date(2024, 1, 1), date(2024, 1, 2)], + "timestamps": [datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC)], + "numbers": [Decimal("1.1"), Decimal("2.2")], + "arrays": [np.array([1, 2]), np.array([3, 4])], + } + } + } + } + + json_message = ToolInvokeMessage.JsonMessage(json_object=nested_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + # Transform the data + transformed_data = safe_json_value(message.message.json_object) + + # Verify nested transformations + level3 = transformed_data["level1"]["level2"]["level3"] + assert all(isinstance(d, str) for d in level3["dates"]) + assert all(isinstance(t, str) for t in level3["timestamps"]) + assert all(isinstance(n, float) for n in level3["numbers"]) + assert all(isinstance(a, list) for a in level3["arrays"]) + + # Test JSON serialization + json_string = json.dumps(transformed_data, ensure_ascii=False) + assert isinstance(json_string, str) + + # Verify deserialization + deserialized = json.loads(json_string) + assert deserialized["level1"]["level2"]["level3"]["numbers"] == [1.1, 2.2] + + def test_json_message_transformer_integration(self): + """Test integration with ToolFileMessageTransformer for JSON messages""" + complex_data = { + "metadata": { + "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "version": Decimal("1.0"), + "tags": ["tag1", "tag2"], + }, + "data": {"values": np.array([1.1, 2.2, 3.3]), "binary": b"binary content"}, + } + + # Create message generator + def message_generator(): + json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + yield message + + # Transform messages + transformed_messages = list( + ToolFileMessageTransformer.transform_tool_invoke_messages( + message_generator(), user_id="test_user", tenant_id="test_tenant" + ) + ) + + assert len(transformed_messages) == 1 + transformed_message = transformed_messages[0] + assert transformed_message.type == ToolInvokeMessage.MessageType.JSON + + # Verify the JSON object was transformed + json_obj = transformed_message.message.json_object + assert isinstance(json_obj["metadata"]["created_at"], str) + assert isinstance(json_obj["metadata"]["version"], float) + assert isinstance(json_obj["data"]["values"], list) + assert isinstance(json_obj["data"]["binary"], str) + + # Test final JSON serialization + final_json = json.dumps(json_obj, ensure_ascii=False) + assert isinstance(final_json, str) + + # Verify we can deserialize + deserialized = json.loads(final_json) + assert deserialized["metadata"]["version"] == 1.0 + assert deserialized["data"]["values"] == [1.1, 2.2, 3.3] + + def test_edge_cases_and_error_handling(self): + """Test edge cases and error handling in JSON serialization""" + # Test with None values + data_with_none = {"null_value": None, "empty_string": "", "zero": 0, "false_value": False} + + json_message = ToolInvokeMessage.JsonMessage(json_object=data_with_none) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + transformed_data = 
safe_json_value(message.message.json_object) + json_string = json.dumps(transformed_data, ensure_ascii=False) + + # Verify serialization works with edge cases + assert json_string is not None + deserialized = json.loads(json_string) + assert deserialized["null_value"] is None + assert deserialized["empty_string"] == "" + assert deserialized["zero"] == 0 + assert deserialized["false_value"] is False + + # Test with very large numbers + large_data = { + "large_int": 2**63 - 1, + "large_float": 1.7976931348623157e308, + "small_float": 2.2250738585072014e-308, + } + + json_message = ToolInvokeMessage.JsonMessage(json_object=large_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + transformed_data = safe_json_value(message.message.json_object) + json_string = json.dumps(transformed_data, ensure_ascii=False) + + # Verify large numbers are handled correctly + deserialized = json.loads(json_string) + assert deserialized["large_int"] == 2**63 - 1 + assert deserialized["large_float"] == 1.7976931348623157e308 + assert deserialized["small_float"] == 2.2250738585072014e-308 + + +class TestEndToEndSerialization: + """Test suite for end-to-end serialization workflow""" + + def test_complete_workflow_with_real_data(self): + """Test complete workflow from complex data to JSON string and back""" + # Simulate real-world complex data structure + real_world_data = { + "user_profile": { + "id": uuid4(), + "name": "John Doe", + "email": "john@example.com", + "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "last_login": datetime(2024, 1, 15, 14, 30, 0, tzinfo=pytz.UTC), + "preferences": {"theme": "dark", "language": "en", "timezone": "UTC"}, + }, + "analytics": { + "session_count": 42, + "total_time": Decimal("123.45"), + "metrics": np.array([1.1, 2.2, 3.3, 4.4, 5.5]), + "events": [ + { + "timestamp": datetime(2024, 1, 1, 10, 0, 0, tzinfo=pytz.UTC), + "action": "login", + "duration": Decimal("5.67"), + }, + { + "timestamp": datetime(2024, 1, 1, 11, 0, 0, tzinfo=pytz.UTC), + "action": "logout", + "duration": Decimal("3600.0"), + }, + ], + }, + "files": [ + { + "id": uuid4(), + "name": "document.pdf", + "size": 1024, + "uploaded_at": datetime(2024, 1, 1, 9, 0, 0, tzinfo=pytz.UTC), + "checksum": b"abc123def456", + } + ], + } + + # Step 1: Create ToolInvokeMessage + json_message = ToolInvokeMessage.JsonMessage(json_object=real_world_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + # Step 2: Apply safe_json_value transformation + transformed_data = safe_json_value(message.message.json_object) + + # Step 3: Serialize to JSON string + json_string = json.dumps(transformed_data, ensure_ascii=False) + + # Step 4: Verify the string is valid JSON + assert isinstance(json_string, str) + assert json_string.startswith("{") + assert json_string.endswith("}") + + # Step 5: Deserialize back to Python object + deserialized_data = json.loads(json_string) + + # Step 6: Verify data integrity + assert deserialized_data["user_profile"]["name"] == "John Doe" + assert deserialized_data["user_profile"]["email"] == "john@example.com" + assert isinstance(deserialized_data["user_profile"]["created_at"], str) + assert isinstance(deserialized_data["analytics"]["total_time"], float) + assert deserialized_data["analytics"]["total_time"] == 123.45 + assert isinstance(deserialized_data["analytics"]["metrics"], list) + assert deserialized_data["analytics"]["metrics"] == [1.1, 2.2, 3.3, 4.4, 5.5] + assert 
isinstance(deserialized_data["files"][0]["checksum"], str) + + # Step 7: Verify all complex types were properly converted + self._verify_all_complex_types_converted(deserialized_data) + + def _verify_all_complex_types_converted(self, data): + """Helper method to verify all complex types were properly converted""" + if isinstance(data, dict): + for key, value in data.items(): + if key in ["id", "checksum"]: + # These should be strings (UUID/bytes converted) + assert isinstance(value, str) + elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]: + # These should be strings (datetime converted) + assert isinstance(value, str) + elif key in ["total_time", "duration"]: + # These should be floats (Decimal converted) + assert isinstance(value, float) + elif key == "metrics": + # This should be a list (ndarray converted) + assert isinstance(value, list) + else: + # Recursively check nested structures + self._verify_all_complex_types_converted(value) + elif isinstance(data, list): + for item in data: + self._verify_all_complex_types_converted(item) diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py new file mode 100644 index 0000000000..bc46fe8322 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -0,0 +1,149 @@ +"""Tests for Celery SSL configuration.""" + +import ssl +from unittest.mock import MagicMock, patch + + +class TestCelerySSLConfiguration: + """Test suite for Celery SSL configuration.""" + + def test_get_celery_ssl_options_when_ssl_disabled(self): + """Test SSL options when REDIS_USE_SSL is False.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is None + + def test_get_celery_ssl_options_when_broker_not_redis(self): + """Test SSL options when broker is not Redis.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is None + + def test_get_celery_ssl_options_with_cert_none(self): + """Test SSL options with CERT_NONE requirement.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_NONE + assert result["ssl_ca_certs"] is None + assert result["ssl_certfile"] is None + assert result["ssl_keyfile"] is None + + def test_get_celery_ssl_options_with_cert_required(self): + """Test SSL options with CERT_REQUIRED and certificates.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" + mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" + mock_config.REDIS_SSL_CERTFILE = "/path/to/client.crt" + mock_config.REDIS_SSL_KEYFILE = "/path/to/client.key" + + with 
patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_REQUIRED + assert result["ssl_ca_certs"] == "/path/to/ca.crt" + assert result["ssl_certfile"] == "/path/to/client.crt" + assert result["ssl_keyfile"] == "/path/to/client.key" + + def test_get_celery_ssl_options_with_cert_optional(self): + """Test SSL options with CERT_OPTIONAL requirement.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" + mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_OPTIONAL + assert result["ssl_ca_certs"] == "/path/to/ca.crt" + + def test_get_celery_ssl_options_with_invalid_cert_reqs(self): + """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_NONE # Should default to CERT_NONE + + def test_celery_init_applies_ssl_to_broker_and_backend(self): + """Test that SSL options are applied to both broker and backend when using Redis.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.CELERY_BACKEND = "redis" + mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + mock_config.CELERY_USE_SENTINEL = False + mock_config.LOG_FORMAT = "%(message)s" + mock_config.LOG_TZ = "UTC" + mock_config.LOG_FILE = None + + # Mock all the scheduler configs + mock_config.CELERY_BEAT_SCHEDULER_TIME = 1 + mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False + mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False + mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False + mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False + mock_config.ENABLE_CLEAN_MESSAGES = False + mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False + mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False + mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from dify_app import DifyApp + from extensions.ext_celery import init_app + + app = DifyApp(__name__) + celery_app = init_app(app) + + # Check that SSL options were applied + assert "broker_use_ssl" in celery_app.conf + assert celery_app.conf["broker_use_ssl"] is not None + assert celery_app.conf["broker_use_ssl"]["ssl_cert_reqs"] == ssl.CERT_NONE + + # Check that SSL is also applied to Redis backend + assert 
"redis_backend_use_ssl" in celery_app.conf + assert celery_app.conf["redis_backend_use_ssl"] is not None diff --git a/api/uv.lock b/api/uv.lock index 870975418f..cecce2bc43 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1253,6 +1253,7 @@ dependencies = [ { name = "flask-cors" }, { name = "flask-login" }, { name = "flask-migrate" }, + { name = "flask-orjson" }, { name = "flask-restful" }, { name = "flask-sqlalchemy" }, { name = "gevent" }, @@ -1440,6 +1441,7 @@ requires-dist = [ { name = "flask-cors", specifier = "~=6.0.0" }, { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, + { name = "flask-orjson", specifier = "~=2.0.0" }, { name = "flask-restful", specifier = "~=0.3.10" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=24.11.1" }, @@ -1859,6 +1861,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" }, ] +[[package]] +name = "flask-orjson" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flask" }, + { name = "orjson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/49/575796f6ddca171d82dbb12762e33166c8b8f8616c946f0a6dfbb9bc3cd6/flask_orjson-2.0.0.tar.gz", hash = "sha256:6df6631437f9bc52cf9821735f896efa5583b5f80712f7d29d9ef69a79986a9c", size = 2974, upload-time = "2024-01-15T00:03:22.236Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/ca/53e14be018a2284acf799830e8cd8e0b263c0fd3dff1ad7b35f8417e7067/flask_orjson-2.0.0-py3-none-any.whl", hash = "sha256:5d15f2ba94b8d6c02aee88fc156045016e83db9eda2c30545fabd640aebaec9d", size = 3622, upload-time = "2024-01-15T00:03:17.511Z" }, +] + [[package]] name = "flask-restful" version = "0.3.10" diff --git a/docker/.env.example b/docker/.env.example index 7a435ad66c..743a1e8bba 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -264,6 +264,15 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +REDIS_SSL_CA_CERTS= +# Path to CA certificate file for SSL verification +REDIS_SSL_CERTFILE= +# Path to client certificate file for SSL authentication +REDIS_SSL_KEYFILE= +# Path to client private key file for SSL authentication REDIS_DB=0 # Whether to use Redis Sentinel mode. 
@@ -861,7 +870,7 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Repository configuration # Core workflow execution repository implementation -# Options: +# Options: # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository @@ -1157,6 +1166,9 @@ MARKETPLACE_API_URL=https://marketplace.dify.ai FORCE_VERIFYING_SIGNATURE=true +PLUGIN_STDIO_BUFFER_SIZE=1024 +PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 + PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 PLUGIN_MAX_EXECUTION_TIMEOUT=600 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 1dbd9b3993..04981f6b7f 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -181,6 +181,8 @@ services: FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 101f8eb323..bcf9588dff 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -71,6 +71,10 @@ x-shared-env: &shared-api-worker-env REDIS_USERNAME: ${REDIS_USERNAME:-} REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} REDIS_USE_SSL: ${REDIS_USE_SSL:-false} + REDIS_SSL_CERT_REQS: ${REDIS_SSL_CERT_REQS:-CERT_NONE} + REDIS_SSL_CA_CERTS: ${REDIS_SSL_CA_CERTS:-} + REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} + REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} @@ -506,6 +510,8 @@ x-shared-env: &shared-api-worker-env MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} @@ -747,6 +753,8 @@ services: FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} diff --git a/web/Dockerfile b/web/Dockerfile index d59039528c..d284efca87 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com" # RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories RUN apk add --no-cache tzdata -RUN npm install -g 
pnpm@10.13.1 +RUN corepack enable ENV PNPM_HOME="/pnpm" ENV PATH="$PNPM_HOME:$PATH" @@ -19,6 +19,9 @@ WORKDIR /app/web COPY package.json . COPY pnpm-lock.yaml . +# Use packageManager from package.json +RUN corepack install + # if you located in China, you can use taobao registry to speed up # RUN pnpm install --frozen-lockfile --registry https://registry.npmmirror.com/ diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index b83e9e6a2a..2ea5bfb769 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -8,7 +8,6 @@ import { } from '@heroicons/react/24/outline' import { RiCloseLine, RiEditFill } from '@remixicon/react' import { get } from 'lodash-es' -import InfiniteScroll from 'react-infinite-scroll-component' import dayjs from 'dayjs' import utc from 'dayjs/plugin/utc' import timezone from 'dayjs/plugin/timezone' @@ -111,7 +110,8 @@ const statusTdRender = (statusCount: StatusCount) => { const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => { const newChatList: IChatItem[] = [] - messages.forEach((item: ChatMessage) => { + try { + messages.forEach((item: ChatMessage) => { const questionFiles = item.message_files?.filter((file: any) => file.belongs_to === 'user') || [] newChatList.push({ id: `question-${item.id}`, @@ -178,7 +178,13 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t parentMessageId: `question-${item.id}`, }) }) - return newChatList + + return newChatList + } + catch (error) { + console.error('getFormattedChatList processing failed:', error) + throw error + } } type IDetailPanel = { @@ -188,6 +194,9 @@ type IDetailPanel = { } function DetailPanel({ detail, onFeedback }: IDetailPanel) { + const MIN_ITEMS_FOR_SCROLL_LOADING = 8 + const SCROLL_THRESHOLD_PX = 50 + const SCROLL_DEBOUNCE_MS = 200 const { userProfile: { timezone } } = useAppContext() const { formatTime } = useTimestamp() const { onClose, appDetail } = useContext(DrawerContext) @@ -204,13 +213,19 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { t } = useTranslation() const [hasMore, setHasMore] = useState(true) const [varValues, setVarValues] = useState>({}) + const isLoadingRef = useRef(false) const [allChatItems, setAllChatItems] = useState([]) const [chatItemTree, setChatItemTree] = useState([]) const [threadChatItems, setThreadChatItems] = useState([]) const fetchData = useCallback(async () => { + if (isLoadingRef.current) + return + try { + isLoadingRef.current = true + if (!hasMore) return @@ -218,8 +233,11 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { conversation_id: detail.id, limit: 10, } - if (allChatItems[0]?.id) - params.first_id = allChatItems[0]?.id.replace('question-', '') + // Use the oldest answer item ID for pagination + const answerItems = allChatItems.filter(item => item.isAnswer) + const oldestAnswerItem = answerItems[answerItems.length - 1] + if (oldestAnswerItem?.id) + params.first_id = oldestAnswerItem.id const messageRes = await fetchChatMessages({ url: `/apps/${appDetail?.id}/chat-messages`, params, @@ -249,15 +267,20 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { } setChatItemTree(tree) - setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id)) + const lastMessageId = newAllChatItems.length > 0 ? 
newAllChatItems[newAllChatItems.length - 1].id : undefined + setThreadChatItems(getThreadMessages(tree, lastMessageId)) } catch (err) { - console.error(err) + console.error('fetchData execution failed:', err) + } + finally { + isLoadingRef.current = false } }, [allChatItems, detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) const switchSibling = useCallback((siblingMessageId: string) => { - setThreadChatItems(getThreadMessages(chatItemTree, siblingMessageId)) + const newThreadChatItems = getThreadMessages(chatItemTree, siblingMessageId) + setThreadChatItems(newThreadChatItems) }, [chatItemTree]) const handleAnnotationEdited = useCallback((query: string, answer: string, index: number) => { @@ -344,13 +367,217 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const fetchInitiated = useRef(false) + // Only load initial messages, don't auto-load more useEffect(() => { if (appDetail?.id && detail.id && appDetail?.mode !== 'completion' && !fetchInitiated.current) { + // Mark as initialized, but don't auto-load more messages fetchInitiated.current = true + // Still call fetchData to get initial messages fetchData() } }, [appDetail?.id, detail.id, appDetail?.mode, fetchData]) + const [isLoading, setIsLoading] = useState(false) + + const loadMoreMessages = useCallback(async () => { + if (isLoading || !hasMore || !appDetail?.id || !detail.id) + return + + setIsLoading(true) + + try { + const params: ChatMessagesRequest = { + conversation_id: detail.id, + limit: 10, + } + + // Use the earliest response item as the first_id + const answerItems = allChatItems.filter(item => item.isAnswer) + const oldestAnswerItem = answerItems[answerItems.length - 1] + if (oldestAnswerItem?.id) { + params.first_id = oldestAnswerItem.id + } + else if (allChatItems.length > 0 && allChatItems[0]?.id) { + const firstId = allChatItems[0].id.replace('question-', '').replace('answer-', '') + params.first_id = firstId + } + + const messageRes = await fetchChatMessages({ + url: `/apps/${appDetail.id}/chat-messages`, + params, + }) + + if (!messageRes.data || messageRes.data.length === 0) { + setHasMore(false) + return + } + + if (messageRes.data.length > 0) { + const varValues = messageRes.data.at(-1)!.inputs + setVarValues(varValues) + } + + setHasMore(messageRes.has_more) + + const newItems = getFormattedChatList( + messageRes.data, + detail.id, + timezone!, + t('appLog.dateTimeFormat') as string, + ) + + // Check for duplicate messages + const existingIds = new Set(allChatItems.map(item => item.id)) + const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id)) + + if (uniqueNewItems.length === 0) { + if (allChatItems.length > 1) { + const nextId = allChatItems[1].id.replace('question-', '').replace('answer-', '') + + const retryParams = { + ...params, + first_id: nextId, + } + + const retryRes = await fetchChatMessages({ + url: `/apps/${appDetail.id}/chat-messages`, + params: retryParams, + }) + + if (retryRes.data && retryRes.data.length > 0) { + const retryItems = getFormattedChatList( + retryRes.data, + detail.id, + timezone!, + t('appLog.dateTimeFormat') as string, + ) + + const retryUniqueItems = retryItems.filter(item => !existingIds.has(item.id)) + if (retryUniqueItems.length > 0) { + const newAllChatItems = [ + ...retryUniqueItems, + ...allChatItems, + ] + + setAllChatItems(newAllChatItems) + + let tree = buildChatItemTree(newAllChatItems) + if (retryRes.has_more === false && detail?.model_config?.configs?.introduction) { + tree = [{ + id: 
'introduction', + isAnswer: true, + isOpeningStatement: true, + content: detail?.model_config?.configs?.introduction ?? 'hello', + feedbackDisabled: true, + children: tree, + }] + } + setChatItemTree(tree) + setHasMore(retryRes.has_more) + setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id)) + return + } + } + } + } + + const newAllChatItems = [ + ...uniqueNewItems, + ...allChatItems, + ] + + setAllChatItems(newAllChatItems) + + let tree = buildChatItemTree(newAllChatItems) + if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) { + tree = [{ + id: 'introduction', + isAnswer: true, + isOpeningStatement: true, + content: detail?.model_config?.configs?.introduction ?? 'hello', + feedbackDisabled: true, + children: tree, + }] + } + setChatItemTree(tree) + + setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id)) + } + catch (error) { + console.error(error) + setHasMore(false) + } + finally { + setIsLoading(false) + } + }, [allChatItems, detail.id, hasMore, isLoading, timezone, t, appDetail]) + + useEffect(() => { + const scrollableDiv = document.getElementById('scrollableDiv') + const outerDiv = scrollableDiv?.parentElement + const chatContainer = document.querySelector('.mx-1.mb-1.grow.overflow-auto') as HTMLElement + + let scrollContainer: HTMLElement | null = null + + if (outerDiv && outerDiv.scrollHeight > outerDiv.clientHeight) { + scrollContainer = outerDiv + } + else if (scrollableDiv && scrollableDiv.scrollHeight > scrollableDiv.clientHeight) { + scrollContainer = scrollableDiv + } + else if (chatContainer && chatContainer.scrollHeight > chatContainer.clientHeight) { + scrollContainer = chatContainer + } + else { + const possibleContainers = document.querySelectorAll('.overflow-auto, .overflow-y-auto') + for (let i = 0; i < possibleContainers.length; i++) { + const container = possibleContainers[i] as HTMLElement + if (container.scrollHeight > container.clientHeight) { + scrollContainer = container + break + } + } + } + + if (!scrollContainer) + return + + let lastLoadTime = 0 + const throttleDelay = 200 + + const handleScroll = () => { + const currentScrollTop = scrollContainer!.scrollTop + const scrollHeight = scrollContainer!.scrollHeight + const clientHeight = scrollContainer!.clientHeight + + const distanceFromTop = currentScrollTop + const distanceFromBottom = scrollHeight - currentScrollTop - clientHeight + + const now = Date.now() + + const isNearTop = distanceFromTop < 30 + // eslint-disable-next-line sonarjs/no-unused-vars + const _distanceFromBottom = distanceFromBottom < 30 + if (isNearTop && hasMore && !isLoading && (now - lastLoadTime > throttleDelay)) { + lastLoadTime = now + loadMoreMessages() + } + } + + scrollContainer.addEventListener('scroll', handleScroll, { passive: true }) + + const handleWheel = (e: WheelEvent) => { + if (e.deltaY < 0) + handleScroll() + } + scrollContainer.addEventListener('wheel', handleWheel, { passive: true }) + + return () => { + scrollContainer!.removeEventListener('scroll', handleScroll) + scrollContainer!.removeEventListener('wheel', handleWheel) + } + }, [hasMore, isLoading, loadMoreMessages]) + const isChatMode = appDetail?.mode !== 'completion' const isAdvanced = appDetail?.mode === 'advanced-chat' @@ -378,6 +605,36 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { return () => cancelAnimationFrame(raf) }, []) + // Add scroll listener to ensure loading is triggered + useEffect(() => { + if (threadChatItems.length >= MIN_ITEMS_FOR_SCROLL_LOADING && hasMore) { 
+ const scrollableDiv = document.getElementById('scrollableDiv') + + if (scrollableDiv) { + let loadingTimeout: NodeJS.Timeout | null = null + + const handleScroll = () => { + const { scrollTop } = scrollableDiv + + // Trigger loading when scrolling near the top + if (scrollTop < SCROLL_THRESHOLD_PX && !isLoadingRef.current) { + if (loadingTimeout) + clearTimeout(loadingTimeout) + + loadingTimeout = setTimeout(fetchData, SCROLL_DEBOUNCE_MS) // 200ms debounce + } + } + + scrollableDiv.addEventListener('scroll', handleScroll) + return () => { + scrollableDiv.removeEventListener('scroll', handleScroll) + if (loadingTimeout) + clearTimeout(loadingTimeout) + } + } + } + }, [threadChatItems.length, hasMore, fetchData]) + return (
{/* Panel Header */} @@ -439,8 +696,8 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { siteInfo={null} />
- : threadChatItems.length < 8 - ?
+ : threadChatItems.length < MIN_ITEMS_FOR_SCROLL_LOADING ? ( +
- :
{/* Put the scroll bar always on the bottom */} - {t('appLog.detail.loading')}...
} - // endMessage={
Nothing more to show
} - // below props only if you need pull down functionality - refreshFunction={fetchData} - pullDownToRefresh - pullDownToRefreshThreshold={50} - // pullDownToRefreshContent={ - //
Pull down to refresh
- // } - // releaseToRefreshContent={ - //
Release to refresh
- // } - // To put endMessage and loader to the top. - style={{ display: 'flex', flexDirection: 'column-reverse' }} - inverse={true} - > +
+ {/* Loading state indicator - only shown when loading */} + {hasMore && isLoading && ( +
+
+ {t('appLog.detail.loading')}... +
+
+ )} + - +
+ ) } {showMessageLogModal && ( diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 688da4c25d..ee9230af12 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -407,8 +407,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } btnClassName={open => cn( - open ? '!bg-black/5 !shadow-none' : '!bg-transparent', - 'h-8 w-8 rounded-md border-none !p-2 hover:!bg-black/5', + open ? '!bg-state-base-hover !shadow-none' : '!bg-transparent', + 'h-8 w-8 rounded-md border-none !p-2 hover:!bg-state-base-hover', ) } popupClassName={ diff --git a/web/app/components/base/chat/chat/loading-anim/index.tsx b/web/app/components/base/chat/chat/loading-anim/index.tsx index dd43ef9c14..801c89fce7 100644 --- a/web/app/components/base/chat/chat/loading-anim/index.tsx +++ b/web/app/components/base/chat/chat/loading-anim/index.tsx @@ -2,6 +2,7 @@ import type { FC } from 'react' import React from 'react' import s from './style.module.css' +import cn from '@/utils/classnames' export type ILoadingAnimProps = { type: 'text' | 'avatar' @@ -11,7 +12,7 @@ const LoadingAnim: FC = ({ type, }) => { return ( -
+
) } export default React.memo(LoadingAnim) diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index 51e33c43d2..53db991e71 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -8,6 +8,7 @@ import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/confirm-add-var' +import PromptEditor from '@/app/components/base/prompt-editor' import type { OpeningStatement } from '@/app/components/base/features/types' import { getInputKeys } from '@/app/components/base/block-input' import type { PromptVariable } from '@/models/debug' @@ -101,7 +102,7 @@ const OpeningSettingModal = ({
·
{tempSuggestedQuestions.length}/{MAX_QUESTION_NUM}
- +
{t('appDebug.feature.conversationOpener.title')}
-
+
-