diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index e1b7c8f4e8..766f55f642 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -28,7 +28,27 @@ jobs: run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all - + # Convert Optional[T] to T | None (ignoring quoted types) + cat > /tmp/optional-rule.yml << 'EOF' + id: convert-optional-to-union + language: python + rule: + kind: generic_type + all: + - has: + kind: identifier + pattern: Optional + - has: + kind: type_parameter + has: + kind: type + pattern: $T + fix: $T | None + EOF + uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all + # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax) + find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; + find . -name "*.py.bak" -type f -delete - name: mdformat run: | diff --git a/api/commands.py b/api/commands.py index d46750316b..0adef498cd 100644 --- a/api/commands.py +++ b/api/commands.py @@ -2,7 +2,7 @@ import base64 import json import logging import secrets -from typing import Any, Optional +from typing import Any import click import sqlalchemy as sa @@ -639,7 +639,7 @@ def old_metadata_migration(): @click.option("--email", prompt=True, help="Tenant account email.") @click.option("--name", prompt=True, help="Workspace name.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): +def create_tenant(email: str, language: str | None = None, name: str | None = None): """ Create tenant account """ diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index f9c4d73463..9694f3db6b 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,28 +7,28 @@ class NotionConfig(BaseSettings): Configuration settings for Notion integration """ - NOTION_CLIENT_ID: Optional[str] = Field( + NOTION_CLIENT_ID: str | None = Field( description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_CLIENT_SECRET: Optional[str] = Field( + NOTION_CLIENT_SECRET: str | None = Field( description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_INTEGRATION_TYPE: Optional[str] = Field( + NOTION_INTEGRATION_TYPE: str | None = Field( description="Type of Notion integration." " Set to 'internal' for internal integrations, or None for public integrations.", default=None, ) - NOTION_INTERNAL_SECRET: Optional[str] = Field( + NOTION_INTERNAL_SECRET: str | None = Field( description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.", default=None, ) - NOTION_INTEGRATION_TOKEN: Optional[str] = Field( + NOTION_INTEGRATION_TOKEN: str | None = Field( description="Integration token for Notion API access. Used for direct API calls without OAuth flow.", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index f76a6bdb95..d72d01b49f 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeFloat from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class SentryConfig(BaseSettings): Configuration settings for Sentry error tracking and performance monitoring """ - SENTRY_DSN: Optional[str] = Field( + SENTRY_DSN: str | None = Field( description="Sentry Data Source Name (DSN)." " This is the unique identifier of your Sentry project, used to send events to the correct project.", default=None, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 0d6f4e416e..0b340c51e7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import ( AliasChoices, @@ -57,7 +57,7 @@ class SecurityConfig(BaseSettings): default=False, ) - ADMIN_API_KEY: Optional[str] = Field( + ADMIN_API_KEY: str | None = Field( description="admin api key for authentication", default=None, ) @@ -97,17 +97,17 @@ class CodeExecutionSandboxConfig(BaseSettings): default="dify-sandbox", ) - CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_CONNECT_TIMEOUT: float | None = Field( description="Connection timeout in seconds for code execution requests", default=10.0, ) - CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_READ_TIMEOUT: float | None = Field( description="Read timeout in seconds for code execution requests", default=60.0, ) - CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_WRITE_TIMEOUT: float | None = Field( description="Write timeout in seconds for code execution request", default=10.0, ) @@ -368,17 +368,17 @@ class HttpConfig(BaseSettings): default=3, ) - SSRF_PROXY_ALL_URL: Optional[str] = Field( + SSRF_PROXY_ALL_URL: str | None = Field( description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTP_URL: Optional[str] = Field( + SSRF_PROXY_HTTP_URL: str | None = Field( description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + SSRF_PROXY_HTTPS_URL: str | None = Field( description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) @@ -420,7 +420,7 @@ class InnerAPIConfig(BaseSettings): default=False, ) - INNER_API_KEY: Optional[str] = Field( + INNER_API_KEY: str | None = Field( description="API key for accessing the internal API", default=None, ) @@ -436,7 +436,7 @@ class LoggingConfig(BaseSettings): default="INFO", ) - LOG_FILE: Optional[str] = Field( + LOG_FILE: str | None = Field( description="File path for log output.", default=None, ) @@ -456,12 +456,12 @@ class LoggingConfig(BaseSettings): default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) - LOG_DATEFORMAT: Optional[str] = Field( + LOG_DATEFORMAT: str | None = Field( description="Date format string for log timestamps", default=None, ) - LOG_TZ: Optional[str] = Field( + LOG_TZ: str | None = Field( description="Timezone for log timestamps (e.g., 'America/New_York')", default="UTC", ) @@ -595,22 +595,22 @@ class AuthConfig(BaseSettings): default="/console/api/oauth/authorize", ) - GITHUB_CLIENT_ID: Optional[str] = Field( + GITHUB_CLIENT_ID: str | None = Field( description="GitHub OAuth client ID", default=None, ) - GITHUB_CLIENT_SECRET: Optional[str] = Field( + GITHUB_CLIENT_SECRET: str | None = Field( description="GitHub OAuth client secret", default=None, ) - GOOGLE_CLIENT_ID: Optional[str] = Field( + GOOGLE_CLIENT_ID: str | None = Field( description="Google OAuth client ID", default=None, ) - GOOGLE_CLIENT_SECRET: Optional[str] = Field( + GOOGLE_CLIENT_SECRET: str | None = Field( description="Google OAuth client secret", default=None, ) @@ -678,42 +678,42 @@ class MailConfig(BaseSettings): Configuration for email services """ - MAIL_TYPE: Optional[str] = Field( + MAIL_TYPE: str | None = Field( description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.", default=None, ) - MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( + MAIL_DEFAULT_SEND_FROM: str | None = Field( description="Default email address to use as the sender", default=None, ) - RESEND_API_KEY: Optional[str] = Field( + RESEND_API_KEY: str | None = Field( description="API key for Resend email service", default=None, ) - RESEND_API_URL: Optional[str] = Field( + RESEND_API_URL: str | None = Field( description="API URL for Resend email service", default=None, ) - SMTP_SERVER: Optional[str] = Field( + SMTP_SERVER: str | None = Field( description="SMTP server hostname", default=None, ) - SMTP_PORT: Optional[int] = Field( + SMTP_PORT: int | None = Field( description="SMTP server port number", default=465, ) - SMTP_USERNAME: Optional[str] = Field( + SMTP_USERNAME: str | None = Field( description="Username for SMTP authentication", default=None, ) - SMTP_PASSWORD: Optional[str] = Field( + SMTP_PASSWORD: str | None = Field( description="Password for SMTP authentication", default=None, ) @@ -733,7 +733,7 @@ class MailConfig(BaseSettings): default=50, ) - SENDGRID_API_KEY: Optional[str] = Field( + SENDGRID_API_KEY: str | None = Field( description="API key for SendGrid service", default=None, ) @@ -756,17 +756,17 @@ class RagEtlConfig(BaseSettings): default="database", ) - UNSTRUCTURED_API_URL: Optional[str] = Field( + UNSTRUCTURED_API_URL: str | None = Field( description="API URL for Unstructured.io service", default=None, ) - UNSTRUCTURED_API_KEY: Optional[str] = Field( + UNSTRUCTURED_API_KEY: str | None = Field( description="API key for Unstructured.io service", default="", ) - SCARF_NO_ANALYTICS: Optional[str] = Field( + SCARF_NO_ANALYTICS: str | None = Field( description="This is about whether to disable Scarf analytics in Unstructured library.", default="false", ) diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 18ef1ed45b..476b397ba1 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt from pydantic_settings import BaseSettings @@ -40,17 +38,17 @@ class HostedOpenAiConfig(BaseSettings): Configuration for hosted OpenAI service """ - HOSTED_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_OPENAI_API_KEY: str | None = Field( description="API key for hosted OpenAI service", default=None, ) - HOSTED_OPENAI_API_BASE: Optional[str] = Field( + HOSTED_OPENAI_API_BASE: str | None = Field( description="Base URL for hosted OpenAI API", default=None, ) - HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( + HOSTED_OPENAI_API_ORGANIZATION: str | None = Field( description="Organization ID for hosted OpenAI service", default=None, ) @@ -110,12 +108,12 @@ class HostedAzureOpenAiConfig(BaseSettings): default=False, ) - HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_KEY: str | None = Field( description="API key for hosted Azure OpenAI service", default=None, ) - HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_BASE: str | None = Field( description="Base URL for hosted Azure OpenAI API", default=None, ) @@ -131,12 +129,12 @@ class HostedAnthropicConfig(BaseSettings): Configuration for hosted Anthropic service """ - HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( + HOSTED_ANTHROPIC_API_BASE: str | None = Field( description="Base URL for hosted Anthropic API", default=None, ) - HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( + HOSTED_ANTHROPIC_API_KEY: str | None = Field( description="API key for hosted Anthropic service", default=None, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 591c24cbe0..dbad90270e 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Any, Literal, Optional +from typing import Any, Literal from urllib.parse import parse_qsl, quote_plus from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field @@ -78,18 +78,18 @@ class StorageConfig(BaseSettings): class VectorStoreConfig(BaseSettings): - VECTOR_STORE: Optional[str] = Field( + VECTOR_STORE: str | None = Field( description="Type of vector store to use for efficient similarity search." " Set to None if not using a vector store.", default=None, ) - VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( + VECTOR_STORE_WHITELIST_ENABLE: bool | None = Field( description="Enable whitelist for vector store.", default=False, ) - VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field( + VECTOR_INDEX_NAME_PREFIX: str | None = Field( description="Prefix used to create collection name in vector database", default="Vector_index", ) @@ -225,26 +225,26 @@ class CeleryConfig(DatabaseConfig): default="redis", ) - CELERY_BROKER_URL: Optional[str] = Field( + CELERY_BROKER_URL: str | None = Field( description="URL of the message broker for Celery tasks.", default=None, ) - CELERY_USE_SENTINEL: Optional[bool] = Field( + CELERY_USE_SENTINEL: bool | None = Field( description="Whether to use Redis Sentinel for high availability.", default=False, ) - CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field( + CELERY_SENTINEL_MASTER_NAME: str | None = Field( description="Name of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_PASSWORD: Optional[str] = Field( + CELERY_SENTINEL_PASSWORD: str | None = Field( description="Password of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field( description="Timeout for Redis Sentinel socket operations in seconds.", default=0.1, ) @@ -268,12 +268,12 @@ class InternalTestConfig(BaseSettings): Configuration settings for Internal Test """ - AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + AWS_SECRET_ACCESS_KEY: str | None = Field( description="Internal test AWS secret access key", default=None, ) - AWS_ACCESS_KEY_ID: Optional[str] = Field( + AWS_ACCESS_KEY_ID: str | None = Field( description="Internal test AWS access key ID", default=None, ) @@ -284,15 +284,15 @@ class DatasetQueueMonitorConfig(BaseSettings): Configuration settings for Dataset Queue Monitor """ - QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field( + QUEUE_MONITOR_THRESHOLD: NonNegativeInt | None = Field( description="Threshold for dataset queue monitor", default=200, ) - QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field( + QUEUE_MONITOR_ALERT_EMAILS: str | None = Field( description="Emails for dataset queue monitor alert, separated by commas", default=None, ) - QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field( + QUEUE_MONITOR_INTERVAL: NonNegativeFloat | None = Field( description="Interval for dataset queue monitor in minutes", default=30, ) diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 16dca98cfa..4705b28c69 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt from pydantic_settings import BaseSettings @@ -19,12 +17,12 @@ class RedisConfig(BaseSettings): default=6379, ) - REDIS_USERNAME: Optional[str] = Field( + REDIS_USERNAME: str | None = Field( description="Username for Redis authentication (if required)", default=None, ) - REDIS_PASSWORD: Optional[str] = Field( + REDIS_PASSWORD: str | None = Field( description="Password for Redis authentication (if required)", default=None, ) @@ -44,47 +42,47 @@ class RedisConfig(BaseSettings): default="CERT_NONE", ) - REDIS_SSL_CA_CERTS: Optional[str] = Field( + REDIS_SSL_CA_CERTS: str | None = Field( description="Path to the CA certificate file for SSL verification", default=None, ) - REDIS_SSL_CERTFILE: Optional[str] = Field( + REDIS_SSL_CERTFILE: str | None = Field( description="Path to the client certificate file for SSL authentication", default=None, ) - REDIS_SSL_KEYFILE: Optional[str] = Field( + REDIS_SSL_KEYFILE: str | None = Field( description="Path to the client private key file for SSL authentication", default=None, ) - REDIS_USE_SENTINEL: Optional[bool] = Field( + REDIS_USE_SENTINEL: bool | None = Field( description="Enable Redis Sentinel mode for high availability", default=False, ) - REDIS_SENTINELS: Optional[str] = Field( + REDIS_SENTINELS: str | None = Field( description="Comma-separated list of Redis Sentinel nodes (host:port)", default=None, ) - REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field( + REDIS_SENTINEL_SERVICE_NAME: str | None = Field( description="Name of the Redis Sentinel service to monitor", default=None, ) - REDIS_SENTINEL_USERNAME: Optional[str] = Field( + REDIS_SENTINEL_USERNAME: str | None = Field( description="Username for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_PASSWORD: Optional[str] = Field( + REDIS_SENTINEL_PASSWORD: str | None = Field( description="Password for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + REDIS_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field( description="Socket timeout in seconds for Redis Sentinel connections", default=0.1, ) @@ -94,12 +92,12 @@ class RedisConfig(BaseSettings): default=False, ) - REDIS_CLUSTERS: Optional[str] = Field( + REDIS_CLUSTERS: str | None = Field( description="Comma-separated list of Redis Clusters nodes (host:port)", default=None, ) - REDIS_CLUSTERS_PASSWORD: Optional[str] = Field( + REDIS_CLUSTERS_PASSWORD: str | None = Field( description="Password for Redis Clusters authentication (if required)", default=None, ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 07eb527170..331c486d54 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,37 +7,37 @@ class AliyunOSSStorageConfig(BaseSettings): Configuration settings for Aliyun Object Storage Service (OSS) """ - ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( + ALIYUN_OSS_BUCKET_NAME: str | None = Field( description="Name of the Aliyun OSS bucket to store and retrieve objects", default=None, ) - ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( + ALIYUN_OSS_ACCESS_KEY: str | None = Field( description="Access key ID for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( + ALIYUN_OSS_SECRET_KEY: str | None = Field( description="Secret access key for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_ENDPOINT: Optional[str] = Field( + ALIYUN_OSS_ENDPOINT: str | None = Field( description="URL of the Aliyun OSS endpoint for your chosen region", default=None, ) - ALIYUN_OSS_REGION: Optional[str] = Field( + ALIYUN_OSS_REGION: str | None = Field( description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')", default=None, ) - ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( + ALIYUN_OSS_AUTH_VERSION: str | None = Field( description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')", default=None, ) - ALIYUN_OSS_PATH: Optional[str] = Field( + ALIYUN_OSS_PATH: str | None = Field( description="Base path within the bucket to store objects (e.g., 'my-app-data/')", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index e14c210718..9277a335f7 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +9,27 @@ class S3StorageConfig(BaseSettings): Configuration settings for S3-compatible object storage """ - S3_ENDPOINT: Optional[str] = Field( + S3_ENDPOINT: str | None = Field( description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')", default=None, ) - S3_REGION: Optional[str] = Field( + S3_REGION: str | None = Field( description="Region where the S3 bucket is located (e.g., 'us-east-1')", default=None, ) - S3_BUCKET_NAME: Optional[str] = Field( + S3_BUCKET_NAME: str | None = Field( description="Name of the S3 bucket to store and retrieve objects", default=None, ) - S3_ACCESS_KEY: Optional[str] = Field( + S3_ACCESS_KEY: str | None = Field( description="Access key ID for authenticating with the S3 service", default=None, ) - S3_SECRET_KEY: Optional[str] = Field( + S3_SECRET_KEY: str | None = Field( description="Secret access key for authenticating with the S3 service", default=None, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index b7ab5247a9..7195d446b1 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class AzureBlobStorageConfig(BaseSettings): Configuration settings for Azure Blob Storage """ - AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_NAME: str | None = Field( description="Name of the Azure Storage account (e.g., 'mystorageaccount')", default=None, ) - AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_KEY: str | None = Field( description="Access key for authenticating with the Azure Storage account", default=None, ) - AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( + AZURE_BLOB_CONTAINER_NAME: str | None = Field( description="Name of the Azure Blob container to store and retrieve objects", default=None, ) - AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_URL: str | None = Field( description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')", default=None, ) diff --git a/api/configs/middleware/storage/baidu_obs_storage_config.py b/api/configs/middleware/storage/baidu_obs_storage_config.py index e7913b0acc..138a0db650 100644 --- a/api/configs/middleware/storage/baidu_obs_storage_config.py +++ b/api/configs/middleware/storage/baidu_obs_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class BaiduOBSStorageConfig(BaseSettings): Configuration settings for Baidu Object Storage Service (OBS) """ - BAIDU_OBS_BUCKET_NAME: Optional[str] = Field( + BAIDU_OBS_BUCKET_NAME: str | None = Field( description="Name of the Baidu OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - BAIDU_OBS_ACCESS_KEY: Optional[str] = Field( + BAIDU_OBS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_SECRET_KEY: Optional[str] = Field( + BAIDU_OBS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_ENDPOINT: Optional[str] = Field( + BAIDU_OBS_ENDPOINT: str | None = Field( description="URL of the Baidu OSS endpoint for your chosen region (e.g., 'https://.bj.bcebos.com')", default=None, ) diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py index 56e1b6a957..035650d98a 100644 --- a/api/configs/middleware/storage/clickzetta_volume_storage_config.py +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -1,7 +1,5 @@ """ClickZetta Volume Storage Configuration""" -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ from pydantic_settings import BaseSettings class ClickZettaVolumeStorageConfig(BaseSettings): """Configuration for ClickZetta Volume storage.""" - CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( + CLICKZETTA_VOLUME_USERNAME: str | None = Field( description="Username for ClickZetta Volume authentication", default=None, ) - CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( + CLICKZETTA_VOLUME_PASSWORD: str | None = Field( description="Password for ClickZetta Volume authentication", default=None, ) - CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( + CLICKZETTA_VOLUME_INSTANCE: str | None = Field( description="ClickZetta instance identifier", default=None, ) @@ -49,7 +47,7 @@ class ClickZettaVolumeStorageConfig(BaseSettings): default="user", ) - CLICKZETTA_VOLUME_NAME: Optional[str] = Field( + CLICKZETTA_VOLUME_NAME: str | None = Field( description="ClickZetta volume name for external volumes", default=None, ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e5d763d7f5..a63eb798a8 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class GoogleCloudStorageConfig(BaseSettings): Configuration settings for Google Cloud Storage """ - GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( + GOOGLE_STORAGE_BUCKET_NAME: str | None = Field( description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')", default=None, ) - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( + GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: str | None = Field( description="Base64-encoded JSON key file for Google Cloud service account authentication", default=None, ) diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py index be983b5187..5b5cd2f750 100644 --- a/api/configs/middleware/storage/huawei_obs_storage_config.py +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class HuaweiCloudOBSStorageConfig(BaseSettings): Configuration settings for Huawei Cloud Object Storage Service (OBS) """ - HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field( + HUAWEI_OBS_BUCKET_NAME: str | None = Field( description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field( + HUAWEI_OBS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SECRET_KEY: Optional[str] = Field( + HUAWEI_OBS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SERVER: Optional[str] = Field( + HUAWEI_OBS_SERVER: str | None = Field( description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", default=None, ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index edc245bcac..70815a0055 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class OCIStorageConfig(BaseSettings): Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage """ - OCI_ENDPOINT: Optional[str] = Field( + OCI_ENDPOINT: str | None = Field( description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')", default=None, ) - OCI_REGION: Optional[str] = Field( + OCI_REGION: str | None = Field( description="OCI region where the bucket is located (e.g., 'us-phoenix-1')", default=None, ) - OCI_BUCKET_NAME: Optional[str] = Field( + OCI_BUCKET_NAME: str | None = Field( description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')", default=None, ) - OCI_ACCESS_KEY: Optional[str] = Field( + OCI_ACCESS_KEY: str | None = Field( description="Access key (also known as API key) for authenticating with OCI Object Storage", default=None, ) - OCI_SECRET_KEY: Optional[str] = Field( + OCI_SECRET_KEY: str | None = Field( description="Secret key associated with the access key for authenticating with OCI Object Storage", default=None, ) diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py index dcf7c20cf9..7f140fc5b9 100644 --- a/api/configs/middleware/storage/supabase_storage_config.py +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ class SupabaseStorageConfig(BaseSettings): Configuration settings for Supabase Object Storage Service """ - SUPABASE_BUCKET_NAME: Optional[str] = Field( + SUPABASE_BUCKET_NAME: str | None = Field( description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')", default=None, ) - SUPABASE_API_KEY: Optional[str] = Field( + SUPABASE_API_KEY: str | None = Field( description="API KEY for authenticating with Supabase", default=None, ) - SUPABASE_URL: Optional[str] = Field( + SUPABASE_URL: str | None = Field( description="URL of the Supabase", default=None, ) diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 255c4e8938..e297e748e9 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class TencentCloudCOSStorageConfig(BaseSettings): Configuration settings for Tencent Cloud Object Storage (COS) """ - TENCENT_COS_BUCKET_NAME: Optional[str] = Field( + TENCENT_COS_BUCKET_NAME: str | None = Field( description="Name of the Tencent Cloud COS bucket to store and retrieve objects", default=None, ) - TENCENT_COS_REGION: Optional[str] = Field( + TENCENT_COS_REGION: str | None = Field( description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')", default=None, ) - TENCENT_COS_SECRET_ID: Optional[str] = Field( + TENCENT_COS_SECRET_ID: str | None = Field( description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SECRET_KEY: Optional[str] = Field( + TENCENT_COS_SECRET_KEY: str | None = Field( description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SCHEME: Optional[str] = Field( + TENCENT_COS_SCHEME: str | None = Field( description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, ) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index 06c3ae4d3e..be01f2dc36 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class VolcengineTOSStorageConfig(BaseSettings): Configuration settings for Volcengine Tinder Object Storage (TOS) """ - VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field( + VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')", default=None, ) - VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field( + VOLCENGINE_TOS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field( + VOLCENGINE_TOS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field( + VOLCENGINE_TOS_ENDPOINT: str | None = Field( description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')", default=None, ) - VOLCENGINE_TOS_REGION: Optional[str] = Field( + VOLCENGINE_TOS_REGION: str | None = Field( description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')", default=None, ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index cb8dc7d724..539b9c0963 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -11,37 +9,37 @@ class AnalyticdbConfig(BaseSettings): https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled """ - ANALYTICDB_KEY_ID: Optional[str] = Field( + ANALYTICDB_KEY_ID: str | None = Field( default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication." ) - ANALYTICDB_KEY_SECRET: Optional[str] = Field( + ANALYTICDB_KEY_SECRET: str | None = Field( default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access." ) - ANALYTICDB_REGION_ID: Optional[str] = Field( + ANALYTICDB_REGION_ID: str | None = Field( default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').", ) - ANALYTICDB_INSTANCE_ID: Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: str | None = Field( default=None, description="The unique identifier of the AnalyticDB instance you want to connect to.", ) - ANALYTICDB_ACCOUNT: Optional[str] = Field( + ANALYTICDB_ACCOUNT: str | None = Field( default=None, description="The account name used to log in to the AnalyticDB instance" " (usually the initial account created with the instance).", ) - ANALYTICDB_PASSWORD: Optional[str] = Field( + ANALYTICDB_PASSWORD: str | None = Field( default=None, description="The password associated with the AnalyticDB account for database authentication." ) - ANALYTICDB_NAMESPACE: Optional[str] = Field( + ANALYTICDB_NAMESPACE: str | None = Field( default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)." ) - ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( + ANALYTICDB_NAMESPACE_PASSWORD: str | None = Field( default=None, description="The password for accessing the specified namespace within the AnalyticDB instance" " (if namespace feature is enabled).", ) - ANALYTICDB_HOST: Optional[str] = Field( + ANALYTICDB_HOST: str | None = Field( default=None, description="The host of the AnalyticDB instance you want to connect to." ) ANALYTICDB_PORT: PositiveInt = Field( diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 44742c2e2f..4b6ddb3bde 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class BaiduVectorDBConfig(BaseSettings): Configuration settings for Baidu Vector Database """ - BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field( + BAIDU_VECTOR_DB_ENDPOINT: str | None = Field( description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", default=None, ) @@ -19,17 +17,17 @@ class BaiduVectorDBConfig(BaseSettings): default=30000, ) - BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( + BAIDU_VECTOR_DB_ACCOUNT: str | None = Field( description="Account for authenticating with the Baidu Vector Database", default=None, ) - BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( + BAIDU_VECTOR_DB_API_KEY: str | None = Field( description="API key for authenticating with the Baidu Vector Database service", default=None, ) - BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( + BAIDU_VECTOR_DB_DATABASE: str | None = Field( description="Name of the specific Baidu Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index e83a9902de..3a78980b91 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class ChromaConfig(BaseSettings): Configuration settings for Chroma vector database """ - CHROMA_HOST: Optional[str] = Field( + CHROMA_HOST: str | None = Field( description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')", default=None, ) @@ -19,22 +17,22 @@ class ChromaConfig(BaseSettings): default=8000, ) - CHROMA_TENANT: Optional[str] = Field( + CHROMA_TENANT: str | None = Field( description="Tenant identifier for multi-tenancy support in Chroma", default=None, ) - CHROMA_DATABASE: Optional[str] = Field( + CHROMA_DATABASE: str | None = Field( description="Name of the Chroma database to connect to", default=None, ) - CHROMA_AUTH_PROVIDER: Optional[str] = Field( + CHROMA_AUTH_PROVIDER: str | None = Field( description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)", default=None, ) - CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( + CHROMA_AUTH_CREDENTIALS: str | None = Field( description="Authentication credentials for Chroma (format depends on the auth provider)", default=None, ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py index 61bc01202b..e8172b5299 100644 --- a/api/configs/middleware/vdb/clickzetta_config.py +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,62 +7,62 @@ class ClickzettaConfig(BaseSettings): Clickzetta Lakehouse vector database configuration """ - CLICKZETTA_USERNAME: Optional[str] = Field( + CLICKZETTA_USERNAME: str | None = Field( description="Username for authenticating with Clickzetta Lakehouse", default=None, ) - CLICKZETTA_PASSWORD: Optional[str] = Field( + CLICKZETTA_PASSWORD: str | None = Field( description="Password for authenticating with Clickzetta Lakehouse", default=None, ) - CLICKZETTA_INSTANCE: Optional[str] = Field( + CLICKZETTA_INSTANCE: str | None = Field( description="Clickzetta Lakehouse instance ID", default=None, ) - CLICKZETTA_SERVICE: Optional[str] = Field( + CLICKZETTA_SERVICE: str | None = Field( description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')", default="api.clickzetta.com", ) - CLICKZETTA_WORKSPACE: Optional[str] = Field( + CLICKZETTA_WORKSPACE: str | None = Field( description="Clickzetta workspace name", default="default", ) - CLICKZETTA_VCLUSTER: Optional[str] = Field( + CLICKZETTA_VCLUSTER: str | None = Field( description="Clickzetta virtual cluster name", default="default_ap", ) - CLICKZETTA_SCHEMA: Optional[str] = Field( + CLICKZETTA_SCHEMA: str | None = Field( description="Database schema name in Clickzetta", default="public", ) - CLICKZETTA_BATCH_SIZE: Optional[int] = Field( + CLICKZETTA_BATCH_SIZE: int | None = Field( description="Batch size for bulk insert operations", default=100, ) - CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field( + CLICKZETTA_ENABLE_INVERTED_INDEX: bool | None = Field( description="Enable inverted index for full-text search capabilities", default=True, ) - CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field( + CLICKZETTA_ANALYZER_TYPE: str | None = Field( description="Analyzer type for full-text search: keyword, english, chinese, unicode", default="chinese", ) - CLICKZETTA_ANALYZER_MODE: Optional[str] = Field( + CLICKZETTA_ANALYZER_MODE: str | None = Field( description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)", default="smart", ) - CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: str | None = Field( description="Distance function for vector similarity: l2_distance or cosine_distance", default="cosine_distance", ) diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py index b81cbf8959..a365e30263 100644 --- a/api/configs/middleware/vdb/couchbase_config.py +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class CouchbaseConfig(BaseSettings): Couchbase configs """ - COUCHBASE_CONNECTION_STRING: Optional[str] = Field( + COUCHBASE_CONNECTION_STRING: str | None = Field( description="COUCHBASE connection string", default=None, ) - COUCHBASE_USER: Optional[str] = Field( + COUCHBASE_USER: str | None = Field( description="COUCHBASE user", default=None, ) - COUCHBASE_PASSWORD: Optional[str] = Field( + COUCHBASE_PASSWORD: str | None = Field( description="COUCHBASE password", default=None, ) - COUCHBASE_BUCKET_NAME: Optional[str] = Field( + COUCHBASE_BUCKET_NAME: str | None = Field( description="COUCHBASE bucket name", default=None, ) - COUCHBASE_SCOPE_NAME: Optional[str] = Field( + COUCHBASE_SCOPE_NAME: str | None = Field( description="COUCHBASE scope name", default=None, ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py index 8c4b333d45..a0efd41417 100644 --- a/api/configs/middleware/vdb/elasticsearch_config.py +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt, model_validator from pydantic_settings import BaseSettings @@ -10,7 +8,7 @@ class ElasticsearchConfig(BaseSettings): Can load from environment variables or .env files. """ - ELASTICSEARCH_HOST: Optional[str] = Field( + ELASTICSEARCH_HOST: str | None = Field( description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')", default="127.0.0.1", ) @@ -20,30 +18,28 @@ class ElasticsearchConfig(BaseSettings): default=9200, ) - ELASTICSEARCH_USERNAME: Optional[str] = Field( + ELASTICSEARCH_USERNAME: str | None = Field( description="Username for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) - ELASTICSEARCH_PASSWORD: Optional[str] = Field( + ELASTICSEARCH_PASSWORD: str | None = Field( description="Password for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) # Elastic Cloud (optional) - ELASTICSEARCH_USE_CLOUD: Optional[bool] = Field( + ELASTICSEARCH_USE_CLOUD: bool | None = Field( description="Set to True to use Elastic Cloud instead of self-hosted Elasticsearch", default=False ) - ELASTICSEARCH_CLOUD_URL: Optional[str] = Field( + ELASTICSEARCH_CLOUD_URL: str | None = Field( description="Full URL for Elastic Cloud deployment (e.g., 'https://example.es.region.aws.found.io:443')", default=None, ) - ELASTICSEARCH_API_KEY: Optional[str] = Field( - description="API key for authenticating with Elastic Cloud", default=None - ) + ELASTICSEARCH_API_KEY: str | None = Field(description="API key for authenticating with Elastic Cloud", default=None) # Common options - ELASTICSEARCH_CA_CERTS: Optional[str] = Field( + ELASTICSEARCH_CA_CERTS: str | None = Field( description="Path to CA certificate file for SSL verification", default=None ) ELASTICSEARCH_VERIFY_CERTS: bool = Field( diff --git a/api/configs/middleware/vdb/huawei_cloud_config.py b/api/configs/middleware/vdb/huawei_cloud_config.py index 2290c60499..d64cb870fa 100644 --- a/api/configs/middleware/vdb/huawei_cloud_config.py +++ b/api/configs/middleware/vdb/huawei_cloud_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ class HuaweiCloudConfig(BaseSettings): Configuration settings for Huawei cloud search service """ - HUAWEI_CLOUD_HOSTS: Optional[str] = Field( + HUAWEI_CLOUD_HOSTS: str | None = Field( description="Hostname or IP address of the Huawei cloud search service instance", default=None, ) - HUAWEI_CLOUD_USER: Optional[str] = Field( + HUAWEI_CLOUD_USER: str | None = Field( description="Username for authenticating with Huawei cloud search service", default=None, ) - HUAWEI_CLOUD_PASSWORD: Optional[str] = Field( + HUAWEI_CLOUD_PASSWORD: str | None = Field( description="Password for authenticating with Huawei cloud search service", default=None, ) diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py index e80e3f4a35..262d5a1f26 100644 --- a/api/configs/middleware/vdb/lindorm_config.py +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class LindormConfig(BaseSettings): Lindorm configs """ - LINDORM_URL: Optional[str] = Field( + LINDORM_URL: str | None = Field( description="Lindorm url", default=None, ) - LINDORM_USERNAME: Optional[str] = Field( + LINDORM_USERNAME: str | None = Field( description="Lindorm user", default=None, ) - LINDORM_PASSWORD: Optional[str] = Field( + LINDORM_PASSWORD: str | None = Field( description="Lindorm password", default=None, ) - DEFAULT_INDEX_TYPE: Optional[str] = Field( + DEFAULT_INDEX_TYPE: str | None = Field( description="Lindorm Vector Index Type, hnsw or flat is available in dify", default="hnsw", ) - DEFAULT_DISTANCE_TYPE: Optional[str] = Field( + DEFAULT_DISTANCE_TYPE: str | None = Field( description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2" ) - USING_UGC_INDEX: Optional[bool] = Field( + USING_UGC_INDEX: bool | None = Field( description="Using UGC index will store the same type of Index in a single index but can retrieve separately.", default=False, ) - LINDORM_QUERY_TIMEOUT: Optional[float] = Field(description="The lindorm search request timeout (s)", default=2.0) + LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index d398ef5bd8..05cee51cc9 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class MilvusConfig(BaseSettings): Configuration settings for Milvus vector database """ - MILVUS_URI: Optional[str] = Field( + MILVUS_URI: str | None = Field( description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')", default="http://127.0.0.1:19530", ) - MILVUS_TOKEN: Optional[str] = Field( + MILVUS_TOKEN: str | None = Field( description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: Optional[str] = Field( + MILVUS_USER: str | None = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, ) - MILVUS_PASSWORD: Optional[str] = Field( + MILVUS_PASSWORD: str | None = Field( description="Password for authenticating with Milvus, if username/password authentication is enabled", default=None, ) @@ -40,7 +38,7 @@ class MilvusConfig(BaseSettings): default=True, ) - MILVUS_ANALYZER_PARAMS: Optional[str] = Field( + MILVUS_ANALYZER_PARAMS: str | None = Field( description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', default=None, ) diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py index 9b11a22732..8437328e76 100644 --- a/api/configs/middleware/vdb/oceanbase_config.py +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class OceanBaseVectorConfig(BaseSettings): Configuration settings for OceanBase Vector database """ - OCEANBASE_VECTOR_HOST: Optional[str] = Field( + OCEANBASE_VECTOR_HOST: str | None = Field( description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')", default=None, ) - OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field( + OCEANBASE_VECTOR_PORT: PositiveInt | None = Field( description="Port number on which the OceanBase Vector server is listening (default is 2881)", default=2881, ) - OCEANBASE_VECTOR_USER: Optional[str] = Field( + OCEANBASE_VECTOR_USER: str | None = Field( description="Username for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field( + OCEANBASE_VECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_DATABASE: Optional[str] = Field( + OCEANBASE_VECTOR_DATABASE: str | None = Field( description="Name of the OceanBase Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/opengauss_config.py b/api/configs/middleware/vdb/opengauss_config.py index 87ea292ab4..b57c1e59a9 100644 --- a/api/configs/middleware/vdb/opengauss_config.py +++ b/api/configs/middleware/vdb/opengauss_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class OpenGaussConfig(BaseSettings): Configuration settings for OpenGauss """ - OPENGAUSS_HOST: Optional[str] = Field( + OPENGAUSS_HOST: str | None = Field( description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class OpenGaussConfig(BaseSettings): default=6600, ) - OPENGAUSS_USER: Optional[str] = Field( + OPENGAUSS_USER: str | None = Field( description="Username for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_PASSWORD: Optional[str] = Field( + OPENGAUSS_PASSWORD: str | None = Field( description="Password for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_DATABASE: Optional[str] = Field( + OPENGAUSS_DATABASE: str | None = Field( description="Name of the OpenGauss database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 9700447a4c..ba015a6eb9 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, Optional +from typing import Literal from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -18,7 +18,7 @@ class OpenSearchConfig(BaseSettings): BASIC = "basic" AWS_MANAGED_IAM = "aws_managed_iam" - OPENSEARCH_HOST: Optional[str] = Field( + OPENSEARCH_HOST: str | None = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, ) @@ -43,21 +43,21 @@ class OpenSearchConfig(BaseSettings): default=AuthMethod.BASIC, ) - OPENSEARCH_USER: Optional[str] = Field( + OPENSEARCH_USER: str | None = Field( description="Username for authenticating with OpenSearch", default=None, ) - OPENSEARCH_PASSWORD: Optional[str] = Field( + OPENSEARCH_PASSWORD: str | None = Field( description="Password for authenticating with OpenSearch", default=None, ) - OPENSEARCH_AWS_REGION: Optional[str] = Field( + OPENSEARCH_AWS_REGION: str | None = Field( description="AWS region for OpenSearch (e.g. 'us-west-2')", default=None, ) - OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field( + OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] | None = Field( description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index ea39909ef4..dc179e8e4f 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,33 +7,33 @@ class OracleConfig(BaseSettings): Configuration settings for Oracle database """ - ORACLE_USER: Optional[str] = Field( + ORACLE_USER: str | None = Field( description="Username for authenticating with the Oracle database", default=None, ) - ORACLE_PASSWORD: Optional[str] = Field( + ORACLE_PASSWORD: str | None = Field( description="Password for authenticating with the Oracle database", default=None, ) - ORACLE_DSN: Optional[str] = Field( + ORACLE_DSN: str | None = Field( description="Oracle database connection string. For traditional database, use format 'host:port/service_name'. " "For autonomous database, use the service name from tnsnames.ora in the wallet", default=None, ) - ORACLE_CONFIG_DIR: Optional[str] = Field( + ORACLE_CONFIG_DIR: str | None = Field( description="Directory containing the tnsnames.ora configuration file. Only used in thin mode connection", default=None, ) - ORACLE_WALLET_LOCATION: Optional[str] = Field( + ORACLE_WALLET_LOCATION: str | None = Field( description="Oracle wallet directory path containing the wallet files for secure connection", default=None, ) - ORACLE_WALLET_PASSWORD: Optional[str] = Field( + ORACLE_WALLET_PASSWORD: str | None = Field( description="Password to decrypt the Oracle wallet, if it is encrypted", default=None, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 9f5f7284d7..62334636a5 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class PGVectorConfig(BaseSettings): Configuration settings for PGVector (PostgreSQL with vector extension) """ - PGVECTOR_HOST: Optional[str] = Field( + PGVECTOR_HOST: str | None = Field( description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class PGVectorConfig(BaseSettings): default=5433, ) - PGVECTOR_USER: Optional[str] = Field( + PGVECTOR_USER: str | None = Field( description="Username for authenticating with the PostgreSQL database", default=None, ) - PGVECTOR_PASSWORD: Optional[str] = Field( + PGVECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the PostgreSQL database", default=None, ) - PGVECTOR_DATABASE: Optional[str] = Field( + PGVECTOR_DATABASE: str | None = Field( description="Name of the PostgreSQL database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index fa3bca5bb7..7bc144c4ab 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class PGVectoRSConfig(BaseSettings): Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL) """ - PGVECTO_RS_HOST: Optional[str] = Field( + PGVECTO_RS_HOST: str | None = Field( description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class PGVectoRSConfig(BaseSettings): default=5431, ) - PGVECTO_RS_USER: Optional[str] = Field( + PGVECTO_RS_USER: str | None = Field( description="Username for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_PASSWORD: Optional[str] = Field( + PGVECTO_RS_PASSWORD: str | None = Field( description="Password for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_DATABASE: Optional[str] = Field( + PGVECTO_RS_DATABASE: str | None = Field( description="Name of the PostgreSQL database with PGVecto.RS extension to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index 0a753eddec..b9e8e861da 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class QdrantConfig(BaseSettings): Configuration settings for Qdrant vector database """ - QDRANT_URL: Optional[str] = Field( + QDRANT_URL: str | None = Field( description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')", default=None, ) - QDRANT_API_KEY: Optional[str] = Field( + QDRANT_API_KEY: str | None = Field( description="API key for authenticating with the Qdrant server", default=None, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index 5ffbea7b19..0ed5357852 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class RelytConfig(BaseSettings): Configuration settings for Relyt database """ - RELYT_HOST: Optional[str] = Field( + RELYT_HOST: str | None = Field( description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')", default=None, ) @@ -19,17 +17,17 @@ class RelytConfig(BaseSettings): default=9200, ) - RELYT_USER: Optional[str] = Field( + RELYT_USER: str | None = Field( description="Username for authenticating with the Relyt database", default=None, ) - RELYT_PASSWORD: Optional[str] = Field( + RELYT_PASSWORD: str | None = Field( description="Password for authenticating with the Relyt database", default=None, ) - RELYT_DATABASE: Optional[str] = Field( + RELYT_DATABASE: str | None = Field( description="Name of the Relyt database to connect to (default is 'default')", default="default", ) diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py index 1aab01c6e1..2cec384b5d 100644 --- a/api/configs/middleware/vdb/tablestore_config.py +++ b/api/configs/middleware/vdb/tablestore_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class TableStoreConfig(BaseSettings): Configuration settings for TableStore. """ - TABLESTORE_ENDPOINT: Optional[str] = Field( + TABLESTORE_ENDPOINT: str | None = Field( description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')", default=None, ) - TABLESTORE_INSTANCE_NAME: Optional[str] = Field( + TABLESTORE_INSTANCE_NAME: str | None = Field( description="Instance name to access TableStore server (eg. 'instance-name')", default=None, ) - TABLESTORE_ACCESS_KEY_ID: Optional[str] = Field( + TABLESTORE_ACCESS_KEY_ID: str | None = Field( description="AccessKey id for the instance name", default=None, ) - TABLESTORE_ACCESS_KEY_SECRET: Optional[str] = Field( + TABLESTORE_ACCESS_KEY_SECRET: str | None = Field( description="AccessKey secret for the instance name", default=None, ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index a51823c3f3..3dc21ab89a 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class TencentVectorDBConfig(BaseSettings): Configuration settings for Tencent Vector Database """ - TENCENT_VECTOR_DB_URL: Optional[str] = Field( + TENCENT_VECTOR_DB_URL: str | None = Field( description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')", default=None, ) - TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( + TENCENT_VECTOR_DB_API_KEY: str | None = Field( description="API key for authenticating with the Tencent Vector Database service", default=None, ) @@ -24,12 +22,12 @@ class TencentVectorDBConfig(BaseSettings): default=30, ) - TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( + TENCENT_VECTOR_DB_USERNAME: str | None = Field( description="Username for authenticating with the Tencent Vector Database (if required)", default=None, ) - TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( + TENCENT_VECTOR_DB_PASSWORD: str | None = Field( description="Password for authenticating with the Tencent Vector Database (if required)", default=None, ) @@ -44,7 +42,7 @@ class TencentVectorDBConfig(BaseSettings): default=2, ) - TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( + TENCENT_VECTOR_DB_DATABASE: str | None = Field( description="Name of the specific Tencent Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py index d2625af264..9ca0955129 100644 --- a/api/configs/middleware/vdb/tidb_on_qdrant_config.py +++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class TidbOnQdrantConfig(BaseSettings): Tidb on Qdrant configs """ - TIDB_ON_QDRANT_URL: Optional[str] = Field( + TIDB_ON_QDRANT_URL: str | None = Field( description="Tidb on Qdrant url", default=None, ) - TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( + TIDB_ON_QDRANT_API_KEY: str | None = Field( description="Tidb on Qdrant api key", default=None, ) @@ -34,37 +32,37 @@ class TidbOnQdrantConfig(BaseSettings): default=6334, ) - TIDB_PUBLIC_KEY: Optional[str] = Field( + TIDB_PUBLIC_KEY: str | None = Field( description="Tidb account public key", default=None, ) - TIDB_PRIVATE_KEY: Optional[str] = Field( + TIDB_PRIVATE_KEY: str | None = Field( description="Tidb account private key", default=None, ) - TIDB_API_URL: Optional[str] = Field( + TIDB_API_URL: str | None = Field( description="Tidb API url", default=None, ) - TIDB_IAM_API_URL: Optional[str] = Field( + TIDB_IAM_API_URL: str | None = Field( description="Tidb IAM API url", default=None, ) - TIDB_REGION: Optional[str] = Field( + TIDB_REGION: str | None = Field( description="Tidb serverless region", default="regions/aws-us-east-1", ) - TIDB_PROJECT_ID: Optional[str] = Field( + TIDB_PROJECT_ID: str | None = Field( description="Tidb project id", default=None, ) - TIDB_SPEND_LIMIT: Optional[int] = Field( + TIDB_SPEND_LIMIT: int | None = Field( description="Tidb spend limit", default=100, ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index bc68be69d8..0ebf226bea 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class TiDBVectorConfig(BaseSettings): Configuration settings for TiDB Vector database """ - TIDB_VECTOR_HOST: Optional[str] = Field( + TIDB_VECTOR_HOST: str | None = Field( description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')", default=None, ) - TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( + TIDB_VECTOR_PORT: PositiveInt | None = Field( description="Port number on which the TiDB Vector server is listening (default is 4000)", default=4000, ) - TIDB_VECTOR_USER: Optional[str] = Field( + TIDB_VECTOR_USER: str | None = Field( description="Username for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_PASSWORD: Optional[str] = Field( + TIDB_VECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_DATABASE: Optional[str] = Field( + TIDB_VECTOR_DATABASE: str | None = Field( description="Name of the TiDB Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/upstash_config.py b/api/configs/middleware/vdb/upstash_config.py index 412c56374a..01a0442f70 100644 --- a/api/configs/middleware/vdb/upstash_config.py +++ b/api/configs/middleware/vdb/upstash_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class UpstashConfig(BaseSettings): Configuration settings for Upstash vector database """ - UPSTASH_VECTOR_URL: Optional[str] = Field( + UPSTASH_VECTOR_URL: str | None = Field( description="URL of the upstash server (e.g., 'https://vector.upstash.io')", default=None, ) - UPSTASH_VECTOR_TOKEN: Optional[str] = Field( + UPSTASH_VECTOR_TOKEN: str | None = Field( description="Token for authenticating with the upstash server", default=None, ) diff --git a/api/configs/middleware/vdb/vastbase_vector_config.py b/api/configs/middleware/vdb/vastbase_vector_config.py index 816d6df90a..ced4cf154c 100644 --- a/api/configs/middleware/vdb/vastbase_vector_config.py +++ b/api/configs/middleware/vdb/vastbase_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class VastbaseVectorConfig(BaseSettings): Configuration settings for Vector (Vastbase with vector extension) """ - VASTBASE_HOST: Optional[str] = Field( + VASTBASE_HOST: str | None = Field( description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class VastbaseVectorConfig(BaseSettings): default=5432, ) - VASTBASE_USER: Optional[str] = Field( + VASTBASE_USER: str | None = Field( description="Username for authenticating with the Vastbase database", default=None, ) - VASTBASE_PASSWORD: Optional[str] = Field( + VASTBASE_PASSWORD: str | None = Field( description="Password for authenticating with the Vastbase database", default=None, ) - VASTBASE_DATABASE: Optional[str] = Field( + VASTBASE_DATABASE: str | None = Field( description="Name of the Vastbase database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py index aba49ff670..3d5306bb61 100644 --- a/api/configs/middleware/vdb/vikingdb_config.py +++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -11,14 +9,14 @@ class VikingDBConfig(BaseSettings): https://www.volcengine.com/docs/6291/65568 """ - VIKINGDB_ACCESS_KEY: Optional[str] = Field( + VIKINGDB_ACCESS_KEY: str | None = Field( description="The Access Key provided by Volcengine VikingDB for API authentication." "Refer to the following documentation for details on obtaining credentials:" "https://www.volcengine.com/docs/6291/65568", default=None, ) - VIKINGDB_SECRET_KEY: Optional[str] = Field( + VIKINGDB_SECRET_KEY: str | None = Field( description="The Secret Key provided by Volcengine VikingDB for API authentication.", default=None, ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index 25000e8bde..6a79412ab8 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class WeaviateConfig(BaseSettings): Configuration settings for Weaviate vector database """ - WEAVIATE_ENDPOINT: Optional[str] = Field( + WEAVIATE_ENDPOINT: str | None = Field( description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')", default=None, ) - WEAVIATE_API_KEY: Optional[str] = Field( + WEAVIATE_API_KEY: str | None = Field( description="API key for authenticating with the Weaviate server", default=None, ) diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py index f02f7dc9ff..55c14ead56 100644 --- a/api/configs/remote_settings_sources/apollo/__init__.py +++ b/api/configs/remote_settings_sources/apollo/__init__.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import Field from pydantic.fields import FieldInfo @@ -15,22 +15,22 @@ class ApolloSettingsSourceInfo(BaseSettings): Packaging build information """ - APOLLO_APP_ID: Optional[str] = Field( + APOLLO_APP_ID: str | None = Field( description="apollo app_id", default=None, ) - APOLLO_CLUSTER: Optional[str] = Field( + APOLLO_CLUSTER: str | None = Field( description="apollo cluster", default=None, ) - APOLLO_CONFIG_URL: Optional[str] = Field( + APOLLO_CONFIG_URL: str | None = Field( description="apollo config url", default=None, ) - APOLLO_NAMESPACE: Optional[str] = Field( + APOLLO_NAMESPACE: str | None = Field( description="apollo namespace", default=None, ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 56c61c2886..fec527e4cb 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,5 +1,3 @@ -from typing import Optional - import flask_restx from flask_login import current_user from flask_restx import Resource, fields, marshal_with @@ -50,7 +48,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[type] = None + resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -103,7 +101,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[type] = None + resource_model: type | None = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 5a871f896a..44aba01820 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, ParamSpec, TypeVar, Union +from typing import ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -12,7 +12,7 @@ P = ParamSpec("P") R = TypeVar("R") -def _load_app_model(app_id: str) -> Optional[App]: +def _load_app_model(app_id: str) -> App | None: assert isinstance(current_user, Account) app_model = ( db.session.query(App) @@ -22,7 +22,7 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): +def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4b94cc0eb6..1602ee6eea 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests from flask import current_app, redirect, request @@ -157,8 +156,8 @@ class OAuthCallback(Resource): ) -def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account: Optional[Account] = Account.get_by_openid(provider, user_info.id) +def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: + account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: with Session(db.engine) as session: diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 6401f804c0..3a8ba64a03 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask_login import current_user from flask_restx import Resource @@ -20,7 +20,7 @@ R = TypeVar("R") T = TypeVar("T") -def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): +def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): @@ -50,7 +50,7 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], return decorator -def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): +def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 21cbf56074..206a5d1cc2 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,5 +1,4 @@ from mimetypes import guess_extension -from typing import Optional from flask_restx import Resource, reqparse from flask_restx.api import HTTPStatus @@ -73,11 +72,11 @@ class PluginUploadFileApi(Resource): nonce: str = args["nonce"] sign: str = args["sign"] tenant_id: str = args["tenant_id"] - user_id: Optional[str] = args.get("user_id") + user_id: str | None = args.get("user_id") user = get_user(tenant_id, user_id) - filename: Optional[str] = file.filename - mimetype: Optional[str] = file.mimetype + filename: str | None = file.filename + mimetype: str | None = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 369e421599..3776d0be0e 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, ParamSpec, TypeVar, cast +from typing import ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in @@ -54,7 +54,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable[P, R]] = None): +def get_user_tenant(view: Callable[P, R] | None = None): def decorator(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): @@ -106,7 +106,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): return decorator(view) -def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): +def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def decorator(view_func: Callable[P, R]): def decorated_view(*args: P.args, **kwargs: P.kwargs): try: diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 80e6037cc7..a8629dca20 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from flask import Response from flask_restx import Resource, reqparse @@ -73,7 +73,7 @@ class MCPAppApi(Resource): ValidationError: Invalid request format or parameters """ args = mcp_request_parser.parse_args() - request_id: Optional[Union[int, str]] = args.get("id") + request_id: Union[int, str] | None = args.get("id") mcp_request = self._parse_mcp_request(args) with Session(db.engine, expire_on_commit=False) as session: @@ -107,7 +107,7 @@ class MCPAppApi(Resource): def _process_mcp_message( self, mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification, - request_id: Optional[Union[int, str]], + request_id: Union[int, str] | None, app: App, mcp_server: AppMCPServer, user_input_form: list[VariableEntity], @@ -130,7 +130,7 @@ class MCPAppApi(Resource): def _handle_request( self, mcp_request: mcp_types.ClientRequest, - request_id: Optional[Union[int, str]], + request_id: Union[int, str] | None, app: App, mcp_server: AppMCPServer, user_input_form: list[VariableEntity], diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index c38465c98e..1a40707c65 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -42,7 +42,7 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): +def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None): def decorator(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): @@ -189,7 +189,7 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None): +def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): @@ -267,7 +267,7 @@ def validate_and_get_api_token(scope: str | None = None): return api_token -def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser: +def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = None) -> EndUser: """ Create or update session terminal based on user ID. """ diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index e79456535a..ba03c4eae4 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -21,7 +21,7 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None): +def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None): def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bcf83de6a..0a874e9085 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,7 @@ import json import logging import uuid -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select @@ -60,8 +60,8 @@ class BaseAgentRunner(AppRunner): message: Message, user_id: str, model_instance: ModelInstance, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, + memory: TokenBufferMemory | None = None, + prompt_messages: list[PromptMessage] | None = None, ): self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query: Optional[str] = "" + self.query: str | None = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d1d5a011e0..25ad6dc060 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -70,12 +70,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): self._prompt_messages_tools = prompt_messages_tools function_call_state = True - llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} + llm_usage: dict[str, LLMUsage | None] = {"usage": None} final_answer = "" prompt_messages: list = [] # Initialize prompt_messages agent_thought_id = "" # Initialize agent_thought_id - def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -122,7 +122,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): callbacks=[], ) - usage_dict: dict[str, Optional[LLMUsage]] = {} + usage_dict: dict[str, LLMUsage | None] = {} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -274,7 +274,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): action: AgentScratchpadUnit.Action, tool_instances: Mapping[str, Tool], message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None, + trace_manager: TraceQueueManager | None = None, ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3a4d31e047..da9a001d84 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,5 +1,4 @@ import json -from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner from core.model_runtime.entities.message_entities import ( @@ -31,7 +30,7 @@ class CotCompletionAgentRunner(CotAgentRunner): return system_prompt - def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str: + def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str: """ Organize historic prompt """ diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 816d2782f0..220feced1d 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field @@ -50,11 +50,11 @@ class AgentScratchpadUnit(BaseModel): "action_input": self.action_input, } - agent_response: Optional[str] = None - thought: Optional[str] = None - action_str: Optional[str] = None - observation: Optional[str] = None - action: Optional[Action] = None + agent_response: str | None = None + thought: str | None = None + action_str: str | None = None + observation: str | None = None + action: Action | None = None def is_final(self) -> bool: """ @@ -81,8 +81,8 @@ class AgentEntity(BaseModel): provider: str model: str strategy: Strategy - prompt: Optional[AgentPromptEntity] = None - tools: Optional[list[AgentToolEntity]] = None + prompt: AgentPromptEntity | None = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 10 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5236266908..dcc1326b33 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Generator from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom @@ -52,14 +52,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): # continue to run until there is not any tool call function_call_state = True - llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} + llm_usage: dict[str, LLMUsage | None] = {"usage": None} final_answer = "" prompt_messages: list = [] # Initialize prompt_messages # get tracing instance trace_manager = app_generate_entity.trace_manager - def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 1133ecf66c..90aa7b5fd4 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -53,7 +53,7 @@ class AgentStrategyParameter(PluginParameter): return cast_parameter_value(self, value) type: AgentStrategyParameterType = Field(..., description="The type of the parameter") - help: Optional[I18nObject] = None + help: I18nObject | None = None def init_frontend_parameter(self, value: Any): return init_frontend_parameter(self, self.type, value) @@ -61,7 +61,7 @@ class AgentStrategyParameter(PluginParameter): class AgentStrategyProviderEntity(BaseModel): identity: AgentStrategyProviderIdentity - plugin_id: Optional[str] = Field(None, description="The id of the plugin") + plugin_id: str | None = Field(None, description="The id of the plugin") class AgentStrategyIdentity(ToolIdentity): @@ -84,9 +84,9 @@ class AgentStrategyEntity(BaseModel): identity: AgentStrategyIdentity parameters: list[AgentStrategyParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The description of the agent strategy") - output_schema: Optional[dict] = None - features: Optional[list[AgentFeature]] = None - meta_version: Optional[str] = None + output_schema: dict | None = None + features: list[AgentFeature] | None = None + meta_version: str | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index a52a1dfd7a..8a9be05dde 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Sequence -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyParameter @@ -16,10 +16,10 @@ class BaseAgentStrategy(ABC): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. @@ -37,9 +37,9 @@ class BaseAgentStrategy(ABC): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 04661581a7..a3cc798352 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter @@ -38,10 +38,10 @@ class PluginAgentStrategy(BaseAgentStrategy): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 97ede178c7..e925d6dd52 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,12 +1,10 @@ -from typing import Optional - from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: + def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 8887d2500c..eab26e5af9 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,12 +1,10 @@ -from typing import Optional - from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[AgentEntity]: + def convert(cls, config: dict) -> AgentEntity | None: """ Convert model config to model config diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index fcbf479e2e..4b824bde76 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional from core.app.app_config.entities import ( DatasetEntity, @@ -14,7 +13,7 @@ from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[DatasetEntity]: + def convert(cls, config: dict) -> DatasetEntity | None: """ Convert model config to model config diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 37745506e9..98ef6eb0c2 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from enum import StrEnum, auto -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel): provider: str model: str - mode: Optional[str] = None + mode: str | None = None parameters: dict[str, Any] = Field(default_factory=dict) stop: list[str] = Field(default_factory=list) @@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): assistant: str prompt: str - role_prefix: Optional[RolePrefixEntity] = None + role_prefix: RolePrefixEntity | None = None class PromptTemplateEntity(BaseModel): @@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel): raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType - simple_prompt_template: Optional[str] = None - advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None - advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None + simple_prompt_template: str | None = None + advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None + advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None class VariableEntityType(StrEnum): @@ -112,7 +112,7 @@ class VariableEntity(BaseModel): type: VariableEntityType required: bool = False hide: bool = False - max_length: Optional[int] = None + max_length: int | None = None options: Sequence[str] = Field(default_factory=list) allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_extensions: Sequence[str] = Field(default_factory=list) @@ -186,8 +186,8 @@ class MetadataFilteringCondition(BaseModel): Metadata Filtering Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) class DatasetRetrieveConfigEntity(BaseModel): @@ -217,18 +217,18 @@ class DatasetRetrieveConfigEntity(BaseModel): return mode raise ValueError(f"invalid retrieve strategy value {value}") - query_variable: Optional[str] = None # Only when app mode is completion + query_variable: str | None = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy - top_k: Optional[int] = None - score_threshold: Optional[float] = 0.0 - rerank_mode: Optional[str] = "reranking_model" - reranking_model: Optional[dict] = None - weights: Optional[dict] = None - reranking_enabled: Optional[bool] = True - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + top_k: int | None = None + score_threshold: float | None = 0.0 + rerank_mode: str | None = "reranking_model" + reranking_model: dict | None = None + weights: dict | None = None + reranking_enabled: bool | None = True + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" + metadata_model_config: ModelConfig | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None class DatasetEntity(BaseModel): @@ -255,8 +255,8 @@ class TextToSpeechEntity(BaseModel): """ enabled: bool - voice: Optional[str] = None - language: Optional[str] = None + voice: str | None = None + language: str | None = None class TracingConfigEntity(BaseModel): @@ -269,15 +269,15 @@ class TracingConfigEntity(BaseModel): class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileUploadConfig] = None - opening_statement: Optional[str] = None + file_upload: FileUploadConfig | None = None + opening_statement: str | None = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False - text_to_speech: Optional[TextToSpeechEntity] = None - trace_config: Optional[TracingConfigEntity] = None + text_to_speech: TextToSpeechEntity | None = None + trace_config: TracingConfigEntity | None = None class AppConfig(BaseModel): @@ -290,7 +290,7 @@ class AppConfig(BaseModel): app_mode: AppMode additional_features: AppAdditionalFeatures variables: list[VariableEntity] = [] - sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None + sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None class EasyUIBasedAppModelConfigFrom(StrEnum): @@ -313,7 +313,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_dict: dict model: ModelConfigEntity prompt_template: PromptTemplateEntity - dataset: Optional[DatasetEntity] = None + dataset: DatasetEntity | None = None external_data_variables: list[ExternalDataVariableEntity] = [] diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 84b032d6ca..42e19001b3 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Literal, Optional, Union, overload +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity: AdvancedChatAppGenerateEntity, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, - conversation: Optional[Conversation] = None, + conversation: Conversation | None = None, stream: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 635754a201..b8e0b5b310 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -231,7 +231,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def query_app_annotations_to_reply( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index cec3b83674..23ce8a7880 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager from threading import Thread -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -233,7 +233,7 @@ class AdvancedChatAppGenerateTaskPipeline: return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id @@ -294,7 +294,7 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: """Fluent validation for graph runtime state.""" if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -411,8 +411,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueTextChunkEvent, *, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle text chunk events.""" @@ -538,8 +538,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" @@ -569,8 +569,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" @@ -601,8 +601,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowFailedEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed events.""" @@ -636,8 +636,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueStopEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle stop events.""" @@ -677,7 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueAdvancedChatMessageEndEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, + graph_runtime_state: GraphRuntimeState | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" @@ -775,10 +775,10 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: Any, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + graph_runtime_state: GraphRuntimeState | None = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, ) -> Generator[StreamResponse, None, None]: """Dispatch events using elegant pattern matching.""" handlers = self._get_event_handlers() @@ -830,15 +830,15 @@ class AdvancedChatAppGenerateTaskPipeline: def _process_stream_response( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response using elegant Fluent Python patterns. Maintains exact same functionality as original 57-if-statement version. """ # Initialize graph runtime state - graph_runtime_state: Optional[GraphRuntimeState] = None + graph_runtime_state: GraphRuntimeState | None = None for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event @@ -888,7 +888,7 @@ class AdvancedChatAppGenerateTaskPipeline: if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): + def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 54d1a9595f..9ce841f432 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): Agent Chatbot App Config Entity. """ - agent: Optional[AgentEntity] = None + agent: AgentEntity | None = None class AgentChatAppConfigManager(BaseAppConfigManager): @@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager): cls, app_model: App, app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None, + conversation: Conversation | None = None, + override_config_dict: dict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6681fc6e48..8f13599ead 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, Union, final +from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session @@ -24,7 +24,7 @@ class BaseAppGenerator: def _prepare_user_inputs( self, *, - user_inputs: Optional[Mapping[str, Any]], + user_inputs: Mapping[str, Any] | None, variables: Sequence["VariableEntity"], tenant_id: str, strict_type_validation: bool = False, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 2a7fe7902b..a58795bccb 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -2,7 +2,7 @@ import queue import time from abc import abstractmethod from enum import IntEnum, auto -from typing import Any, Optional +from typing import Any from sqlalchemy.orm import DeclarativeMeta @@ -116,7 +116,7 @@ class AppQueueManager: Set task stop flag :return: """ - result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id)) + result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id)) if result is None: return diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index dafdcdd429..e7db3bc41b 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -82,11 +82,11 @@ class AppRunner: prompt_template_entity: PromptTemplateEntity, inputs: Mapping[str, str], files: Sequence["File"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + query: str | None = None, + context: str | None = None, + memory: TokenBufferMemory | None = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages :param context: @@ -161,7 +161,7 @@ class AppRunner: prompt_messages: list, text: str, stream: bool, - usage: Optional[LLMUsage] = None, + usage: LLMUsage | None = None, ): """ Direct output @@ -375,7 +375,7 @@ class AppRunner: def query_app_annotations_to_reply( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 96a3db8502..4b6720a3c3 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager): cls, app_model: App, app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None, + conversation: Conversation | None = None, + override_config_dict: dict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 937b2a7dd7..1b4d28a5b8 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,7 +1,7 @@ import time from collections.abc import Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from sqlalchemy.orm import Session @@ -140,7 +140,7 @@ class WorkflowResponseConverter: event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeStartStreamResponse]: + ) -> NodeStartStreamResponse | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: @@ -190,7 +190,7 @@ class WorkflowResponseConverter: | QueueNodeExceptionEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: + ) -> NodeFinishStreamResponse | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: @@ -235,7 +235,7 @@ class WorkflowResponseConverter: event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + ) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 3a1f29689d..eb1902f12e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None ) -> CompletionAppConfig: """ Convert app model config to completion app config diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 92f3b6507c..170c6a274b 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -84,7 +84,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): logger.exception("Failed to handle response, conversation_id: %s", conversation.id) raise e - def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig: if conversation: stmt = select(AppModelConfig).where( AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id @@ -112,7 +112,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity, ], - conversation: Optional[Conversation] = None, + conversation: Conversation | None = None, ) -> tuple[Conversation, Message]: """ Initialize generate records diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 60395f0416..83c29ca166 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Literal, Optional, Union, overload +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -53,7 +53,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -67,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, ) -> Mapping[str, Any]: ... @overload @@ -81,7 +81,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... def generate( @@ -94,7 +94,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -200,7 +200,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ @@ -434,7 +434,7 @@ class WorkflowAppGenerator(BaseAppGenerator): queue_manager: AppQueueManager, context: contextvars.Context, variable_loader: VariableLoader, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ): """ Generate worker in a new thread. diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 42b3575807..3026be27f8 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, cast +from typing import cast from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager @@ -31,7 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, variable_loader: VariableLoader, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, workflow: Workflow, system_user_id: str, ): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 1c950063dd..638c4e938c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy.orm import Session @@ -206,7 +206,7 @@ class WorkflowAppGenerateTaskPipeline: return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id @@ -268,7 +268,7 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: """Fluent validation for graph runtime state.""" if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -474,8 +474,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" @@ -508,8 +508,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" @@ -543,8 +543,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed and stop events.""" @@ -581,8 +581,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueTextChunkEvent, *, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle text chunk events.""" @@ -635,10 +635,10 @@ class WorkflowAppGenerateTaskPipeline: self, event: Any, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + graph_runtime_state: GraphRuntimeState | None = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, ) -> Generator[StreamResponse, None, None]: """Dispatch events using elegant pattern matching.""" handlers = self._get_event_handlers() @@ -701,8 +701,8 @@ class WorkflowAppGenerateTaskPipeline: def _process_stream_response( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response using elegant Fluent Python patterns. @@ -769,7 +769,7 @@ class WorkflowAppGenerateTaskPipeline: session.commit() def _text_chunk_to_stream_response( - self, text: str, from_variable_selector: Optional[list[str]] = None + self, text: str, from_variable_selector: list[str] | None = None ) -> TextChunkStreamResponse: """ Handle completed event. diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1d5ebabaf7..4c0abd0983 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -96,7 +96,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: Any = None - file_upload_config: Optional[FileUploadConfig] = None + file_upload_config: FileUploadConfig | None = None inputs: Mapping[str, Any] files: Sequence[File] @@ -114,7 +114,7 @@ class AppGenerateEntity(BaseModel): # tracing instance # Using Any to avoid circular import with TraceQueueManager - trace_manager: Optional[Any] = None + trace_manager: Any | None = None class EasyUIBasedAppGenerateEntity(AppGenerateEntity): @@ -126,7 +126,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): app_config: EasyUIBasedAppConfig = None # type: ignore model_conf: ModelConfigWithCredentialsEntity - query: Optional[str] = None + query: str | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -137,8 +137,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity): Base entity for conversation-based app generation. """ - conversation_id: Optional[str] = None - parent_message_id: Optional[str] = Field( + conversation_id: str | None = None + parent_message_id: str | None = Field( default=None, description=( "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API." @@ -188,7 +188,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): # app config app_config: WorkflowUIBasedAppConfig = None # type: ignore - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None query: str class SingleIterationRunEntity(BaseModel): @@ -199,7 +199,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): node_id: str inputs: Mapping - single_iteration_run: Optional[SingleIterationRunEntity] = None + single_iteration_run: SingleIterationRunEntity | None = None class SingleLoopRunEntity(BaseModel): """ @@ -209,7 +209,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): node_id: str inputs: Mapping - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None class WorkflowAppGenerateEntity(AppGenerateEntity): @@ -229,7 +229,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None + single_iteration_run: SingleIterationRunEntity | None = None class SingleLoopRunEntity(BaseModel): """ @@ -239,4 +239,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 8d20b6bc1b..6d2808b447 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -81,20 +81,20 @@ class QueueIterationStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None + metadata: Mapping[str, Any] | None = None class QueueIterationNextEvent(AppQueueEvent): @@ -109,19 +109,19 @@ class QueueIterationNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" node_run_index: int - output: Optional[Any] = None # output for the current iteration - duration: Optional[float] = None + output: Any | None = None # output for the current iteration + duration: float | None = None class QueueIterationCompletedEvent(AppQueueEvent): @@ -135,23 +135,23 @@ class QueueIterationCompletedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueLoopStartEvent(AppQueueEvent): @@ -164,20 +164,20 @@ class QueueLoopStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None + metadata: Mapping[str, Any] | None = None class QueueLoopNextEvent(AppQueueEvent): @@ -192,19 +192,19 @@ class QueueLoopNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" node_run_index: int - output: Optional[Any] = None # output for the current loop - duration: Optional[float] = None + output: Any | None = None # output for the current loop + duration: float | None = None class QueueLoopCompletedEvent(AppQueueEvent): @@ -218,23 +218,23 @@ class QueueLoopCompletedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueTextChunkEvent(AppQueueEvent): @@ -244,11 +244,11 @@ class QueueTextChunkEvent(AppQueueEvent): event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None """from variable selector""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -285,9 +285,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: Sequence[RetrievalSourceMetadata] - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -306,7 +306,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.MESSAGE_END - llm_result: Optional[LLMResult] = None + llm_result: LLMResult | None = None class QueueAdvancedChatMessageEndEvent(AppQueueEvent): @@ -332,7 +332,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class QueueWorkflowFailedEvent(AppQueueEvent): @@ -352,7 +352,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED exceptions_count: int - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class QueueNodeStartedEvent(AppQueueEvent): @@ -367,23 +367,23 @@ class QueueNodeStartedEvent(AppQueueEvent): node_type: NodeType node_data: BaseNodeData node_run_index: int = 1 - predecessor_node_id: Optional[str] = None - parallel_id: Optional[str] = None + predecessor_node_id: str | None = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" - agent_strategy: Optional[AgentNodeStrategyInit] = None + agent_strategy: AgentNodeStrategyInit | None = None class QueueNodeSucceededEvent(AppQueueEvent): @@ -397,30 +397,30 @@ class QueueNodeSucceededEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None - error: Optional[str] = None + error: str | None = None """single iteration duration map""" - iteration_duration_map: Optional[dict[str, float]] = None + iteration_duration_map: dict[str, float] | None = None """single loop duration map""" - loop_duration_map: Optional[dict[str, float]] = None + loop_duration_map: dict[str, float] | None = None class QueueAgentLogEvent(AppQueueEvent): @@ -436,7 +436,7 @@ class QueueAgentLogEvent(AppQueueEvent): error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, Any] | None = None node_id: str @@ -445,10 +445,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): event: QueueEvent = QueueEvent.RETRY - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str retry_index: int # retry index @@ -465,24 +465,24 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -498,24 +498,24 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -531,24 +531,24 @@ class QueueNodeExceptionEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -564,24 +564,24 @@ class QueueNodeFailedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -610,7 +610,7 @@ class QueueErrorEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ERROR - error: Optional[Any] = None + error: Any | None = None class QueuePingEvent(AppQueueEvent): @@ -689,13 +689,13 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -708,13 +708,13 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -727,12 +727,12 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 717e5d715a..92be2fce37 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -110,7 +110,7 @@ class MessageStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE id: str answer: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None class MessageAudioStreamResponse(StreamResponse): @@ -139,7 +139,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = Field(default_factory=dict) - files: Optional[Sequence[Mapping[str, Any]]] = None + files: Sequence[Mapping[str, Any]] | None = None class MessageFileStreamResponse(StreamResponse): @@ -172,12 +172,12 @@ class AgentThoughtStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int - thought: Optional[str] = None - observation: Optional[str] = None - tool: Optional[str] = None - tool_labels: Optional[dict] = None - tool_input: Optional[str] = None - message_files: Optional[list[str]] = None + thought: str | None = None + observation: str | None = None + tool: str | None = None + tool_labels: dict | None = None + tool_input: str | None = None + message_files: list[str] | None = None class AgentMessageStreamResponse(StreamResponse): @@ -223,16 +223,16 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int - created_by: Optional[dict] = None + created_by: dict | None = None created_at: int finished_at: int - exceptions_count: Optional[int] = 0 - files: Optional[Sequence[Mapping[str, Any]]] = [] + exceptions_count: int | None = 0 + files: Sequence[Mapping[str, Any]] | None = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -254,18 +254,18 @@ class NodeStartStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None created_at: int extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None - parallel_run_id: Optional[str] = None - agent_strategy: Optional[AgentNodeStrategyInit] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None + parallel_run_id: str | None = None + agent_strategy: AgentNodeStrategyInit | None = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -311,23 +311,23 @@ class NodeFinishStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -380,23 +380,23 @@ class NodeRetryStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None retry_index: int = 0 event: StreamEvent = StreamEvent.NODE_RETRY @@ -448,10 +448,10 @@ class ParallelBranchStartStreamResponse(StreamResponse): parallel_id: str parallel_branch_id: str - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None created_at: int event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED @@ -471,12 +471,12 @@ class ParallelBranchFinishedStreamResponse(StreamResponse): parallel_id: str parallel_branch_id: str - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None status: str - error: Optional[str] = None + error: str | None = None created_at: int event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED @@ -502,8 +502,8 @@ class IterationNodeStartStreamResponse(StreamResponse): extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -526,12 +526,12 @@ class IterationNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - pre_iteration_output: Optional[Any] = None + pre_iteration_output: Any | None = None extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None - duration: Optional[float] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parallel_mode_run_id: str | None = None + duration: float | None = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -552,19 +552,19 @@ class IterationNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping | None = None finished_at: int steps: int - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str @@ -589,8 +589,8 @@ class LoopNodeStartStreamResponse(StreamResponse): extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_STARTED workflow_run_id: str @@ -613,12 +613,12 @@ class LoopNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - pre_loop_output: Optional[Any] = None + pre_loop_output: Any | None = None extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None - duration: Optional[float] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parallel_mode_run_id: str | None = None + duration: float | None = None event: StreamEvent = StreamEvent.LOOP_NEXT workflow_run_id: str @@ -639,19 +639,19 @@ class LoopNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping | None = None finished_at: int steps: int - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_COMPLETED workflow_run_id: str @@ -669,7 +669,7 @@ class TextChunkStreamResponse(StreamResponse): """ text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -731,7 +731,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): WorkflowAppStreamResponse entity """ - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None class AppBlockingResponse(BaseModel): @@ -796,8 +796,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int @@ -825,7 +825,7 @@ class AgentLogStreamResponse(StreamResponse): error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, Any] | None = None node_id: str event: StreamEvent = StreamEvent.AGENT_LOG diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 3853dccdc5..79fbafe39e 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from sqlalchemy import select @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) class AnnotationReplyFeature: def query( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 6f13f11da0..ffa10cd43c 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -3,7 +3,7 @@ import time import uuid from collections.abc import Generator, Mapping from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Union from core.errors.error import AppInvokeQuotaExceededError from extensions.ext_redis import redis_client @@ -63,7 +63,7 @@ class RateLimit: if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) - def enter(self, request_id: Optional[str] = None) -> str: + def enter(self, request_id: str | None = None) -> str: if self.disabled(): return RateLimit._UNLIMITED_REQUEST_ID if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 4931300901..45e3c0006b 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional from sqlalchemy import select from sqlalchemy.orm import Session @@ -101,7 +100,7 @@ class BasedGenerateTaskPipeline: """ return PingStreamResponse(task_id=self._application_generate_entity.task_id) - def _init_output_moderation(self) -> Optional[OutputModeration]: + def _init_output_moderation(self) -> OutputModeration | None: """ Init output moderation. :return: @@ -118,7 +117,7 @@ class BasedGenerateTaskPipeline: ) return None - def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + def handle_output_moderation_when_task_finished(self, completion: str) -> str | None: """ Handle output moderation when task finished. :param completion: completion diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index d4d5c9f7d2..67abb569e3 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): task_state=self._task_state, ) - self._conversation_name_generate_thread: Optional[Thread] = None + self._conversation_name_generate_thread: Thread | None = None def process( self, @@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id @@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): + def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None): """ Save message. :return: @@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) - def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: + def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None: """ Agent thought to stream response. :param event: agent thought event :return: """ with Session(db.engine, expire_on_commit=False) as session: - agent_thought: Optional[MessageAgentThought] = ( + agent_thought: MessageAgentThought | None = ( session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() ) diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2cefb61df3..90ffdcf1f6 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -1,6 +1,6 @@ import logging from threading import Thread -from typing import Optional, Union +from typing import Union from flask import Flask, current_app from sqlalchemy import select @@ -52,7 +52,7 @@ class MessageCycleManager: self._application_generate_entity = application_generate_entity self._task_state = task_state - def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ Generate conversation name. :param conversation_id: conversation id @@ -111,7 +111,7 @@ class MessageCycleManager: db.session.commit() db.session.close() - def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None: """ Handle annotation reply. :param event: event @@ -141,7 +141,7 @@ class MessageCycleManager: if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata.retriever_resources = event.retriever_resources - def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None: """ Message file to stream response. :param event: event @@ -180,7 +180,7 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + self, answer: str, message_id: str, from_variable_selector: list[str] | None = None ) -> MessageStreamResponse: """ Message to stream response. diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 89190c36cc..1e0fba6215 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -5,7 +5,6 @@ import queue import re import threading from collections.abc import Iterable -from typing import Optional from core.app.entities.queue_entities import ( MessageQueueMessage, @@ -56,7 +55,7 @@ def _process_future( class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None): + def __init__(self, tenant_id: str, voice: str, language: str | None = None): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" @@ -73,7 +72,7 @@ class AppGeneratorTTSPublisher: if not voice or voice not in values: self.voice = self.voices[0].get("value") self.max_sentence = 2 - self._last_audio_event: Optional[AudioTrunk] = None + self._last_audio_event: AudioTrunk | None = None # FIXME better way to handle this threading.start threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 30cdab26dc..9ee02acc92 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,5 +1,5 @@ from collections.abc import Iterable, Mapping -from typing import Any, Optional, TextIO, Union +from typing import Any, TextIO, Union from pydantic import BaseModel @@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str: return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): +def print_text(text: str, color: str | None = None, end: str = "", file: TextIO | None = None): """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) @@ -34,10 +34,10 @@ def print_text(text: str, color: Optional[str] = None, end: str = "", file: Opti class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = "" + color: str | None = "" current_loop: int = 1 - def __init__(self, color: Optional[str] = None): + def __init__(self, color: str | None = None): super().__init__() """Initialize callback handler.""" # use a specific color is not specified @@ -58,9 +58,9 @@ class DifyAgentCallbackHandler(BaseModel): tool_name: str, tool_inputs: Mapping[str, Any], tool_outputs: Iterable[ToolInvokeMessage] | str, - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, + message_id: str | None = None, + timer: Any | None = None, + trace_manager: TraceQueueManager | None = None, ): """If not the final action, print out observation.""" if dify_config.DEBUG: @@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel): else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): + def on_agent_finish(self, color: str | None = None, **kwargs: Any): """Run on agent end.""" if dify_config.DEBUG: print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 350b18772b..23aabd9970 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Iterable, Mapping -from typing import Any, Optional +from typing import Any from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text from core.ops.ops_trace_manager import TraceQueueManager @@ -14,9 +14,9 @@ class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): tool_name: str, tool_inputs: Mapping[str, Any], tool_outputs: Iterable[ToolInvokeMessage], - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, + message_id: str | None = None, + timer: Any | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[ToolInvokeMessage, None, None]: for tool_output in tool_outputs: print_text("\n[on_tool_execution]\n", color=self.color) diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 90c9879733..6143b9b703 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,11 +1,9 @@ -from typing import Optional - from pydantic import BaseModel class PreviewDetail(BaseModel): content: str - child_chunks: Optional[list[str]] = None + child_chunks: list[str] | None = None class QAPreviewDetail(BaseModel): @@ -16,4 +14,4 @@ class QAPreviewDetail(BaseModel): class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] - qa_preview: Optional[list[QAPreviewDetail]] = None + qa_preview: list[QAPreviewDetail] | None = None diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 4794a691bd..663a8164c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from enum import StrEnum, auto -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -29,8 +28,8 @@ class SimpleModelProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: list[ModelType] def __init__(self, provider_entity: ProviderEntity): @@ -92,8 +91,8 @@ class DefaultModelProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] = [] diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 5309e4e638..d694a27942 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -4,7 +4,6 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from typing import Optional from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import func, select @@ -92,7 +91,7 @@ class ProviderConfiguration(BaseModel): ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) - def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -165,7 +164,7 @@ class ProviderConfiguration(BaseModel): return credentials - def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: + def get_system_configuration_status(self) -> SystemConfigurationStatus | None: """ Get system configuration status. :return: @@ -793,9 +792,7 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_custom_model_credential( - self, model_type: ModelType, model: str, credential_id: str | None - ) -> Optional[dict]: + def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: """ Get custom model credentials. @@ -1272,7 +1269,7 @@ class ProviderConfiguration(BaseModel): return model_setting - def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: + def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: """ Get provider model setting. :param model_type: model type @@ -1448,7 +1445,7 @@ class ProviderConfiguration(BaseModel): def get_provider_model( self, model_type: ModelType, model: str, only_active: bool = False - ) -> Optional[ModelWithProviderEntity]: + ) -> ModelWithProviderEntity | None: """ Get provider model. :param model_type: model type @@ -1465,7 +1462,7 @@ class ProviderConfiguration(BaseModel): return None def get_provider_models( - self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None + self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None ) -> list[ModelWithProviderEntity]: """ Get provider models. @@ -1649,7 +1646,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], - model: Optional[str] = None, + model: str | None = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -1783,7 +1780,7 @@ class ProviderConfigurations(BaseModel): super().__init__(tenant_id=tenant_id) def get_models( - self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False ) -> list[ModelWithProviderEntity]: """ Get available models. diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index ad23f4381d..0496959ce2 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,5 +1,5 @@ from enum import StrEnum, auto -from typing import Optional, Union +from typing import Union from pydantic import BaseModel, ConfigDict, Field @@ -49,7 +49,7 @@ class SystemConfigurationStatus(StrEnum): class RestrictModel(BaseModel): model: str - base_model_name: Optional[str] = None + base_model_name: str | None = None model_type: ModelType # pydantic configs @@ -84,9 +84,9 @@ class SystemConfiguration(BaseModel): """ enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None + current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] - credentials: Optional[dict] = None + credentials: dict | None = None class CustomProviderConfiguration(BaseModel): @@ -95,8 +95,8 @@ class CustomProviderConfiguration(BaseModel): """ credentials: dict - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None + current_credential_id: str | None = None + current_credential_name: str | None = None available_credentials: list[CredentialConfiguration] = [] @@ -108,10 +108,10 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType credentials: dict | None = None - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None + current_credential_id: str | None = None + current_credential_name: str | None = None available_model_credentials: list[CredentialConfiguration] = [] - unadded_to_model_list: Optional[bool] = False + unadded_to_model_list: bool | None = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -131,7 +131,7 @@ class CustomConfiguration(BaseModel): Model class for provider custom configuration. """ - provider: Optional[CustomProviderConfiguration] = None + provider: CustomProviderConfiguration | None = None models: list[CustomModelConfiguration] = [] can_added_models: list[UnaddedModelConfiguration] = [] @@ -205,12 +205,12 @@ class ProviderConfig(BasicProviderConfig): scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None required: bool = False - default: Optional[Union[int, str, float, bool]] = None - options: Optional[list[Option]] = None - label: Optional[I18nObject] = None - help: Optional[I18nObject] = None - url: Optional[str] = None - placeholder: Optional[I18nObject] = None + default: Union[int, str, float, bool] | None = None + options: list[Option] | None = None + label: I18nObject | None = None + help: I18nObject | None = None + url: str | None = None + placeholder: I18nObject | None = None def to_basic_provider_config(self) -> BasicProviderConfig: return BasicProviderConfig(type=self.type, name=self.name) diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 642f24a411..8c1ba98ae1 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,12 +1,9 @@ -from typing import Optional - - class LLMError(ValueError): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 8cb9b4ac58..c2789a7a35 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -4,7 +4,7 @@ import logging import os from enum import StrEnum, auto from pathlib import Path -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -19,12 +19,12 @@ class ExtensionModule(StrEnum): class ModuleExtension(BaseModel): - extension_class: Optional[Any] = None + extension_class: Any | None = None name: str - label: Optional[dict] = None - form_schema: Optional[list] = None + label: dict | None = None + form_schema: list | None = None builtin: bool = True - position: Optional[int] = None + position: int | None = None class Extensible: @@ -32,9 +32,9 @@ class Extensible: name: str tenant_id: str - config: Optional[dict] = None + config: dict | None = None - def __init__(self, tenant_id: str, config: Optional[dict] = None): + def __init__(self, tenant_id: str, config: dict | None = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 45878e763f..564801f189 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor @@ -39,7 +37,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: dict, query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 81f1aaf174..cbec2e4e42 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from core.extension.extensible import Extensible, ExtensionModule @@ -16,7 +15,7 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): + def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @@ -34,7 +33,7 @@ class ExternalDataTool(Extensible, ABC): raise NotImplementedError @abstractmethod - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: dict, query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 6a9703a569..86bbb7060c 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from typing import Any, Optional +from typing import Any from flask import Flask, current_app @@ -63,7 +63,7 @@ class ExternalDataFetch: external_data_tool: ExternalDataVariableEntity, inputs: Mapping[str, Any], query: str, - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Query external data tool. :param flask_app: flask app diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 538bc3f525..6c542d681b 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -26,7 +26,7 @@ class ExternalDataToolFactory: # FIXME mypy issue here, figure out how to fix it extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/file/models.py b/api/core/file/models.py index 9b74fa387f..dbef7564d6 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -26,7 +26,7 @@ class FileUploadConfig(BaseModel): File Upload Entity. """ - image_config: Optional[ImageConfig] = None + image_config: ImageConfig | None = None allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_extensions: Sequence[str] = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) @@ -38,21 +38,21 @@ class File(BaseModel): # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY - id: Optional[str] = None # message file id + id: str | None = None # message file id tenant_id: str type: FileType transfer_method: FileTransferMethod # If `transfer_method` is `FileTransferMethod.remote_url`, the # `remote_url` attribute must not be `None`. - remote_url: Optional[str] = None # remote url + remote_url: str | None = None # remote url # If `transfer_method` is `FileTransferMethod.local_file` or # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. # # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: Optional[str] = None - filename: Optional[str] = None - extension: Optional[str] = Field(default=None, description="File extension, should contain dot") - mime_type: Optional[str] = None + related_id: str | None = None + filename: str | None = None + extension: str | None = Field(default=None, description="File extension, should contain dot") + mime_type: str | None = None size: int = -1 # Those properties are private, should not be exposed to the outside. @@ -61,19 +61,19 @@ class File(BaseModel): def __init__( self, *, - id: Optional[str] = None, + id: str | None = None, tenant_id: str, type: FileType, transfer_method: FileTransferMethod, - remote_url: Optional[str] = None, - related_id: Optional[str] = None, - filename: Optional[str] = None, - extension: Optional[str] = None, - mime_type: Optional[str] = None, + remote_url: str | None = None, + related_id: str | None = None, + filename: str | None = None, + extension: str | None = None, + mime_type: str | None = None, size: int = -1, - storage_key: Optional[str] = None, - dify_model_identity: Optional[str] = FILE_MODEL_IDENTITY, - url: Optional[str] = None, + storage_key: str | None = None, + dify_model_identity: str | None = FILE_MODEL_IDENTITY, + url: str | None = None, ): super().__init__( id=id, @@ -108,7 +108,7 @@ class File(BaseModel): return text - def generate_url(self) -> Optional[str]: + def generate_url(self) -> str | None: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.remote_url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 2b580cb373..c44a8e1840 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -2,7 +2,7 @@ import logging from collections.abc import Mapping from enum import StrEnum from threading import Lock -from typing import Any, Optional +from typing import Any from httpx import Timeout, post from pydantic import BaseModel @@ -24,8 +24,8 @@ class CodeExecutionError(Exception): class CodeExecutionResponse(BaseModel): class Data(BaseModel): - stdout: Optional[str] = None - error: Optional[str] = None + stdout: str | None = None + error: str | None = None code: int message: str diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 2f160628ca..00fcfe0b80 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,7 +1,6 @@ import json from enum import StrEnum from json import JSONDecodeError -from typing import Optional from extensions.ext_redis import redis_client @@ -16,7 +15,7 @@ class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """ Get cached model provider credentials. diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 26e738fced..ffb5148386 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any, Optional +from typing import Any from extensions.ext_redis import redis_client @@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC): """Generate cache key based on subclass implementation""" pass - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider credentials""" cached_credentials = redis_client.get(self.cache_key) if cached_credentials: @@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): class NoOpProviderCredentialCache: """No-op provider credential cache""" - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider credentials""" return None diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index c2cc2e4db0..54674d4ff6 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,7 +1,6 @@ import json from enum import StrEnum from json import JSONDecodeError -from typing import Optional from extensions.ext_redis import redis_client @@ -19,7 +18,7 @@ class ToolParameterCache: f":identity_id:{identity_id}" ) - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """ Get cached model provider credentials. diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 35e6e292d1..820502e558 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -1,7 +1,7 @@ import contextlib import re from collections.abc import Mapping -from typing import Any, Optional +from typing import Any def is_valid_trace_id(trace_id: str) -> bool: @@ -13,7 +13,7 @@ def is_valid_trace_id(trace_id: str) -> bool: return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id)) -def get_external_trace_id(request: Any) -> Optional[str]: +def get_external_trace_id(request: Any) -> str | None: """ Retrieve the trace_id from the request. @@ -61,7 +61,7 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]): return {} -def get_trace_id_from_otel_context() -> Optional[str]: +def get_trace_id_from_otel_context() -> str | None: """ Retrieve the current trace ID from the active OpenTelemetry trace context. Returns None if: @@ -88,7 +88,7 @@ def get_trace_id_from_otel_context() -> Optional[str]: return None -def parse_traceparent_header(traceparent: str) -> Optional[str]: +def parse_traceparent_header(traceparent: str) -> str | None: """ Parse the `traceparent` header to extract the trace_id. diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index a5d7f7aac7..af860a1070 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,5 +1,3 @@ -from typing import Optional - from flask import Flask from pydantic import BaseModel @@ -30,8 +28,8 @@ class FreeHostingQuota(HostingQuota): class HostingProvider(BaseModel): enabled: bool = False - credentials: Optional[dict] = None - quota_unit: Optional[QuotaUnit] = None + credentials: dict | None = None + quota_unit: QuotaUnit | None = None quotas: list[HostingQuota] = [] @@ -42,7 +40,7 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: provider_map: dict[str, HostingProvider] - moderation_config: Optional[HostedModerationConfig] = None + moderation_config: HostedModerationConfig | None = None def __init__(self): self.provider_map = {} diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index ed02b70b03..94e88b55b9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,7 +5,7 @@ import re import threading import time import uuid -from typing import Any, Optional +from typing import Any from flask import current_app from sqlalchemy import select @@ -230,9 +230,9 @@ class IndexingRunner: tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: Optional[str] = None, + doc_form: str | None = None, doc_language: str = "English", - dataset_id: Optional[str] = None, + dataset_id: str | None = None, indexing_technique: str = "economy", ) -> IndexingEstimate: """ @@ -421,7 +421,7 @@ class IndexingRunner: max_tokens: int, chunk_overlap: int, separator: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. @@ -655,7 +655,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + document_id: str, after_indexing_status: str, extra_update_params: dict | None = None ): """ Update the document indexing status. diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index d4c4f10a12..83c727ffe0 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,7 +2,7 @@ import json import logging import re from collections.abc import Sequence -from typing import Optional, cast +from typing import cast import json_repair @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class LLMGenerator: @classmethod def generate_conversation_name( - cls, tenant_id: str, query, conversation_id: Optional[str] = None, app_id: Optional[str] = None + cls, tenant_id: str, query, conversation_id: str | None = None, app_id: str | None = None ): prompt = CONVERSATION_TITLE_PROMPT diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index e0b70c132f..1e302b7668 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from enum import StrEnum -from typing import Any, Literal, Optional, cast, overload +from typing import Any, Literal, cast, overload import json_repair from pydantic import TypeAdapter, ValidationError @@ -51,12 +51,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[True], - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... @overload def invoke_llm_with_structured_output( @@ -66,12 +66,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[False], - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload def invoke_llm_with_structured_output( @@ -81,12 +81,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... def invoke_llm_with_structured_output( *, @@ -95,12 +95,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ Invoke large language model with structured output @@ -166,7 +166,7 @@ def invoke_llm_with_structured_output( def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]: result_text: str = "" prompt_messages: Sequence[PromptMessage] = [] - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None for event in llm_result: if isinstance(event, LLMResultChunk): prompt_messages = event.prompt_messages diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 5626849edf..7d938a8a7d 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -4,7 +4,6 @@ import json import os import secrets import urllib.parse -from typing import Optional from urllib.parse import urljoin, urlparse import httpx @@ -122,7 +121,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: return False, "" -def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: +def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" # First check if the server supports OAuth 2.0 Resource Discovery support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) @@ -152,7 +151,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N def start_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, redirect_url: str, provider_id: str, @@ -207,7 +206,7 @@ def start_authorization( def exchange_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, authorization_code: str, code_verifier: str, @@ -242,7 +241,7 @@ def exchange_authorization( def refresh_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, refresh_token: str, ) -> OAuthTokens: @@ -273,7 +272,7 @@ def refresh_authorization( def register_client( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, ) -> OAuthClientInformationFull: """Performs OAuth 2.0 Dynamic Client Registration.""" @@ -297,8 +296,8 @@ def register_client( def auth( provider: OAuthClientProvider, server_url: str, - authorization_code: Optional[str] = None, - state_param: Optional[str] = None, + authorization_code: str | None = None, + state_param: str | None = None, for_list: bool = False, ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index bf1820f744..3a550eb1b6 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -1,5 +1,3 @@ -from typing import Optional - from configs import dify_config from core.mcp.types import ( OAuthClientInformation, @@ -37,7 +35,7 @@ class OAuthClientProvider: client_uri="https://github.com/langgenius/dify", ) - def client_information(self) -> Optional[OAuthClientInformation]: + def client_information(self) -> OAuthClientInformation | None: """Loads information about this OAuth client.""" client_information = self.mcp_provider.decrypted_credentials.get("client_information", {}) if not client_information: @@ -51,7 +49,7 @@ class OAuthClientProvider: {"client_information": client_information.model_dump()}, ) - def tokens(self) -> Optional[OAuthTokens]: + def tokens(self) -> OAuthTokens | None: """Loads any existing OAuth tokens for the current session.""" credentials = self.mcp_provider.decrypted_credentials if not credentials: diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 1012dc2810..86ec2c4db9 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -2,7 +2,7 @@ import logging from collections.abc import Callable from contextlib import AbstractContextManager, ExitStack from types import TracebackType -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from core.mcp.client.sse_client import sse_client @@ -21,11 +21,11 @@ class MCPClient: provider_id: str, tenant_id: str, authed: bool = True, - authorization_code: Optional[str] = None, + authorization_code: str | None = None, for_list: bool = False, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): # Initialize info self.provider_id = provider_id @@ -46,9 +46,9 @@ class MCPClient: self.token = self.provider.tokens() # Initialize session and client objects - self._session: Optional[ClientSession] = None - self._streams_context: Optional[AbstractContextManager[Any]] = None - self._session_context: Optional[ClientSession] = None + self._session: ClientSession | None = None + self._streams_context: AbstractContextManager[Any] | None = None + self._session_context: ClientSession | None = None self._exit_stack = ExitStack() # Whether the client has been initialized @@ -59,9 +59,7 @@ class MCPClient: self._initialized = True return self - def __exit__( - self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType] - ): + def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None): self.cleanup() def _initialize( diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index fbad5576aa..653b3773c0 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -4,7 +4,7 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Optional, Self, TypeVar +from typing import Any, Generic, Self, TypeVar from httpx import HTTPStatusError from pydantic import BaseModel @@ -212,7 +212,7 @@ class BaseSession( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, - metadata: Optional[MessageMetadata] = None, + metadata: MessageMetadata | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index a2c3157b3b..7399e8a4b6 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -5,7 +5,6 @@ from typing import ( Any, Generic, Literal, - Optional, TypeAlias, TypeVar, ) @@ -1173,45 +1172,45 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: Optional[MessageMetadata] = None + metadata: MessageMetadata | None = None class OAuthClientMetadata(BaseModel): client_name: str redirect_uris: list[str] - grant_types: Optional[list[str]] = None - response_types: Optional[list[str]] = None - token_endpoint_auth_method: Optional[str] = None - client_uri: Optional[str] = None - scope: Optional[str] = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + token_endpoint_auth_method: str | None = None + client_uri: str | None = None + scope: str | None = None class OAuthClientInformation(BaseModel): client_id: str - client_secret: Optional[str] = None + client_secret: str | None = None class OAuthClientInformationFull(OAuthClientInformation): client_name: str | None = None redirect_uris: list[str] - scope: Optional[str] = None - grant_types: Optional[list[str]] = None - response_types: Optional[list[str]] = None - token_endpoint_auth_method: Optional[str] = None + scope: str | None = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + token_endpoint_auth_method: str | None = None class OAuthTokens(BaseModel): access_token: str token_type: str - expires_in: Optional[int] = None - refresh_token: Optional[str] = None - scope: Optional[str] = None + expires_in: int | None = None + refresh_token: str | None = None + scope: str | None = None class OAuthMetadata(BaseModel): authorization_endpoint: str token_endpoint: str - registration_endpoint: Optional[str] = None + registration_endpoint: str | None = None response_types_supported: list[str] - grant_types_supported: Optional[list[str]] = None - code_challenge_methods_supported: Optional[list[str]] = None + grant_types_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1f2525cfed..35af742f2a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Optional from sqlalchemy import select @@ -96,7 +95,7 @@ class TokenBufferMemory: return AssistantPromptMessage(content=prompt_message_contents) def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: int | None = None ) -> Sequence[PromptMessage]: """ Get history prompt messages. @@ -187,7 +186,7 @@ class TokenBufferMemory: human_prefix: str = "Human", ai_prefix: str = "Assistant", max_token_limit: int = 2000, - message_limit: Optional[int] = None, + message_limit: int | None = None, ) -> str: """ Get history prompt text. diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 10df2ad79e..a63e94d59c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -103,47 +103,47 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[True] = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator: ... @overload def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[False] = False, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResult: ... @overload def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -176,7 +176,7 @@ class ModelInstance: ) def get_llm_num_tokens( - self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None + self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None ) -> int: """ Get number of tokens for llm @@ -199,7 +199,7 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> TextEmbeddingResult: """ Invoke large language model @@ -246,9 +246,9 @@ class ModelInstance: self, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -276,7 +276,7 @@ class ModelInstance: ), ) - def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -297,7 +297,7 @@ class ModelInstance: ), ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: + def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: """ Invoke large language model @@ -318,7 +318,7 @@ class ModelInstance: ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: """ Invoke large language tts model @@ -397,7 +397,7 @@ class ModelInstance: except Exception as e: raise e - def get_tts_voices(self, language: Optional[str] = None): + def get_tts_voices(self, language: str | None = None): """ Invoke large language tts model voices @@ -470,7 +470,7 @@ class LBModelManager: model_type: ModelType, model: str, load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None, + managed_credentials: dict | None = None, ): """ Load balancing model manager @@ -495,7 +495,7 @@ class LBModelManager: else: load_balancing_config.credentials = managed_credentials - def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: + def fetch_next(self) -> ModelLoadBalancingConfiguration | None: """ Get next model load balancing config Strategy: Round Robin diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 5ce4c23dbb..a745a91510 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -31,10 +30,10 @@ class Callback(ABC): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Before invoke callback @@ -60,10 +59,10 @@ class Callback(ABC): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ On new chunk callback @@ -90,10 +89,10 @@ class Callback(ABC): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ After invoke callback @@ -120,10 +119,10 @@ class Callback(ABC): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Invoke error callback @@ -141,7 +140,7 @@ class Callback(ABC): """ raise NotImplementedError() - def print_text(self, text: str, color: Optional[str] = None, end: str = ""): + def print_text(self, text: str, color: str | None = None, end: str = ""): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 8411afca92..b366fcc57b 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -2,7 +2,7 @@ import json import logging import sys from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,10 +20,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Before invoke callback @@ -76,10 +76,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ On new chunk callback @@ -106,10 +106,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ After invoke callback @@ -147,10 +147,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Invoke error callback diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 659ad59bd6..c7353de5af 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -8,7 +6,7 @@ class I18nObject(BaseModel): Model class for i18n object. """ - zh_Hans: Optional[str] = None + zh_Hans: str | None = None en_US: str def __init__(self, **data): diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index d5caddb7a3..17f6000d93 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Any, Optional, TypedDict, Union +from typing import Any, TypedDict, Union from pydantic import BaseModel, Field @@ -150,13 +150,13 @@ class LLMResult(BaseModel): Model class for llm result. """ - id: Optional[str] = None + id: str | None = None model: str prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) message: AssistantPromptMessage usage: LLMUsage - system_fingerprint: Optional[str] = None - reasoning_content: Optional[str] = None + system_fingerprint: str | None = None + reasoning_content: str | None = None class LLMStructuredOutput(BaseModel): @@ -164,7 +164,7 @@ class LLMStructuredOutput(BaseModel): Model class for llm structured output. """ - structured_output: Optional[Mapping[str, Any]] = None + structured_output: Mapping[str, Any] | None = None class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): @@ -180,8 +180,8 @@ class LLMResultChunkDelta(BaseModel): index: int message: AssistantPromptMessage - usage: Optional[LLMUsage] = None - finish_reason: Optional[str] = None + usage: LLMUsage | None = None + finish_reason: str | None = None class LLMResultChunk(BaseModel): @@ -191,7 +191,7 @@ class LLMResultChunk(BaseModel): model: str prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None delta: LLMResultChunkDelta diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 75ffe7c32f..9235c881e0 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,7 +1,7 @@ from abc import ABC from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, Field, field_serializer, field_validator @@ -146,8 +146,8 @@ class PromptMessage(ABC, BaseModel): """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContentUnionTypes]] = None - name: Optional[str] = None + content: str | list[PromptMessageContentUnionTypes] | None = None + name: str | None = None def is_empty(self) -> bool: """ @@ -193,8 +193,8 @@ class PromptMessage(ABC, BaseModel): @field_serializer("content") def serialize_content( - self, content: Optional[Union[str, Sequence[PromptMessageContent]]] - ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]: + self, content: Union[str, Sequence[PromptMessageContent]] | None + ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: if content is None or isinstance(content, str): return content if isinstance(content, list): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 8259335c50..aee6ce1108 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,6 +1,6 @@ from decimal import Decimal from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, model_validator @@ -154,7 +154,7 @@ class ProviderModel(BaseModel): model: str label: I18nObject model_type: ModelType - features: Optional[list[ModelFeature]] = None + features: list[ModelFeature] | None = None fetch_from: FetchFrom model_properties: dict[ModelPropertyKey, Any] deprecated: bool = False @@ -171,15 +171,15 @@ class ParameterRule(BaseModel): """ name: str - use_template: Optional[str] = None + use_template: str | None = None label: I18nObject type: ParameterType - help: Optional[I18nObject] = None + help: I18nObject | None = None required: bool = False - default: Optional[Any] = None - min: Optional[float] = None - max: Optional[float] = None - precision: Optional[int] = None + default: Any | None = None + min: float | None = None + max: float | None = None + precision: int | None = None options: list[str] = [] @@ -189,7 +189,7 @@ class PriceConfig(BaseModel): """ input: Decimal - output: Optional[Decimal] = None + output: Decimal | None = None unit: Decimal currency: str @@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel): """ parameter_rules: list[ParameterRule] = [] - pricing: Optional[PriceConfig] = None + pricing: PriceConfig | None = None @model_validator(mode="after") def validate_model(self): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 451c2359b3..2ccc9e0eae 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from enum import Enum, StrEnum, auto -from typing import Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -62,9 +61,9 @@ class CredentialFormSchema(BaseModel): label: I18nObject type: FormType required: bool = True - default: Optional[str] = None - options: Optional[list[FormOption]] = None - placeholder: Optional[I18nObject] = None + default: str | None = None + options: list[FormOption] | None = None + placeholder: I18nObject | None = None max_length: int = 0 show_on: list[FormShowOnObject] = [] @@ -79,7 +78,7 @@ class ProviderCredentialSchema(BaseModel): class FieldModelSchema(BaseModel): label: I18nObject - placeholder: Optional[I18nObject] = None + placeholder: I18nObject | None = None class ModelCredentialSchema(BaseModel): @@ -98,8 +97,8 @@ class SimpleProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -120,24 +119,24 @@ class ProviderEntity(BaseModel): provider: str label: I18nObject - description: Optional[I18nObject] = None - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None - icon_small_dark: Optional[I18nObject] = None - icon_large_dark: Optional[I18nObject] = None - background: Optional[str] = None - help: Optional[ProviderHelpEntity] = None + description: I18nObject | None = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None + icon_small_dark: I18nObject | None = None + icon_large_dark: I18nObject | None = None + background: str | None = None + help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] configurate_methods: list[ConfigurateMethod] models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: Optional[ProviderCredentialSchema] = None - model_credential_schema: Optional[ModelCredentialSchema] = None + provider_credential_schema: ProviderCredentialSchema | None = None + model_credential_schema: ModelCredentialSchema | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) # position from plugin _position.yaml - position: Optional[dict[str, list[str]]] = {} + position: dict[str, list[str]] | None = {} @field_validator("models", mode="before") @classmethod diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 6bcb707684..80cf01fb6c 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,12 +1,9 @@ -from typing import Optional - - class InvokeError(ValueError): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description def __str__(self): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index f41818e270..a3d743c373 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import hashlib from threading import Lock -from typing import Optional from pydantic import BaseModel, ConfigDict, Field @@ -99,7 +98,7 @@ class AIModel(BaseModel): model_schema = self.get_model_schema(model, credentials) # get price info from predefined model schema - price_config: Optional[PriceConfig] = None + price_config: PriceConfig | None = None if model_schema and model_schema.pricing: price_config = model_schema.pricing @@ -132,7 +131,7 @@ class AIModel(BaseModel): currency=price_config.currency, ) - def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: """ Get model schema by model name and credentials @@ -171,7 +170,7 @@ class AIModel(BaseModel): return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials @@ -229,7 +228,7 @@ class AIModel(BaseModel): return schema - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 1d7fd7d447..80dabffa10 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Sequence -from typing import Optional, Union +from typing import Union from pydantic import ConfigDict @@ -94,12 +94,12 @@ class LargeLanguageModel(AIModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ Invoke large language model @@ -243,11 +243,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunk, None, None]: """ Invoke result generator @@ -328,7 +328,7 @@ class LargeLanguageModel(AIModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> int: """ Get number of tokens for given prompt messages @@ -403,11 +403,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger before invoke callbacks @@ -451,11 +451,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger new chunk callbacks @@ -498,11 +498,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger after invoke callbacks @@ -548,11 +548,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger invoke error callbacks diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 19dc1d599a..c3ce6f17ad 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -1,5 +1,4 @@ import time -from typing import Optional from pydantic import ConfigDict @@ -18,7 +17,7 @@ class ModerationModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: """ Invoke moderation model diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 569e756a3b..81a434405f 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -19,9 +17,9 @@ class RerankModel(AIModel): credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index c69f65b681..57d7ccf350 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -1,4 +1,4 @@ -from typing import IO, Optional +from typing import IO from pydantic import ConfigDict @@ -17,7 +17,7 @@ class Speech2TextModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: """ Invoke speech to text model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 3ce438955a..8b335c4951 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType @@ -24,7 +22,7 @@ class TextEmbeddingModel(AIModel): model: str, credentials: dict, texts: list[str], - user: Optional[str] = None, + user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: """ diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index 8f8a638af6..23d36c03af 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -1,10 +1,10 @@ import logging from threading import Lock -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) -_tokenizer: Optional[Any] = None +_tokenizer: Any | None = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 9ee29f2f2f..ca391162a0 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,6 +1,5 @@ import logging from collections.abc import Iterable -from typing import Optional from pydantic import ConfigDict @@ -28,7 +27,7 @@ class TTSModel(AIModel): credentials: dict, content_text: str, voice: str, - user: Optional[str] = None, + user: str | None = None, ) -> Iterable[bytes]: """ Invoke large language model @@ -56,7 +55,7 @@ class TTSModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) - def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None): + def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): """ Retrieves the list of voices supported by a given text-to-speech (TTS) model. diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 35c1dfb144..2434425933 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -2,7 +2,6 @@ import hashlib import logging from collections.abc import Sequence from threading import Lock -from typing import Optional import contexts from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -206,9 +205,9 @@ class ModelProviderFactory: def get_models( self, *, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - provider_configs: Optional[list[ProviderConfig]] = None, + provider: str | None = None, + model_type: ModelType | None = None, + provider_configs: list[ProviderConfig] | None = None, ) -> list[SimpleProviderEntity]: """ Get all models for given model type diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index f65339fbfc..c758eaf49f 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -8,7 +8,7 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6 from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from uuid import UUID from pydantic import BaseModel @@ -98,7 +98,7 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, + custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, sqlalchemy_safe: bool = True, ) -> Any: custom_encoder = custom_encoder or {} diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index ce7bd21110..573f4ec2a7 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from sqlalchemy import select @@ -87,7 +85,7 @@ class ApiModeration(Moderation): return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension | None: stmt = select(APIBasedExtension).where( APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 340483c894..d76b4689be 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from enum import StrEnum, auto -from typing import Optional from pydantic import BaseModel, Field @@ -34,7 +33,7 @@ class Moderation(Extensible, ABC): module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None): + def __init__(self, app_id: str, tenant_id: str, config: dict | None = None): super().__init__(tenant_id, config) self.app_id = app_id diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 3ac33966cb..21dc58f16f 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -21,7 +21,7 @@ class InputModeration: inputs: Mapping[str, Any], query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None, + trace_manager: TraceQueueManager | None = None, ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 6993ec8b0b..a97e3d4253 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import Any, Optional +from typing import Any from flask import Flask, current_app from pydantic import BaseModel, ConfigDict @@ -27,11 +27,11 @@ class OutputModeration(BaseModel): rule: ModerationRule queue_manager: AppQueueManager - thread: Optional[threading.Thread] = None + thread: threading.Thread | None = None thread_running: bool = True buffer: str = "" is_final_chunk: bool = False - final_output: Optional[str] = None + final_output: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def should_direct_output(self) -> bool: @@ -127,7 +127,7 @@ class OutputModeration(BaseModel): if result.action == ModerationAction.DIRECT_OUTPUT: break - def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: + def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> ModerationOutputsResult | None: try: moderation_factory = ModerationFactory( name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index c661050637..db8684139d 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,7 +1,6 @@ import json import logging from collections.abc import Sequence -from typing import Optional from urllib.parse import urljoin from opentelemetry.trace import Link, Status, StatusCode @@ -123,7 +122,7 @@ class AliyunDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 881ec2141c..09cb6e3fc1 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,7 +7,6 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime -from typing import Optional import requests from opentelemetry import trace as trace_api @@ -184,7 +183,7 @@ def generate_span_id() -> int: return span_id -def convert_to_trace_id(uuid_v4: Optional[str]) -> int: +def convert_to_trace_id(uuid_v4: str | None) -> int: try: uuid_obj = uuid.UUID(uuid_v4) return uuid_obj.int @@ -192,7 +191,7 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int: raise ValueError(f"Invalid UUID input: {e}") -def convert_string_to_id(string: Optional[str]) -> int: +def convert_string_to_id(string: str | None) -> int: if not string: return generate_span_id() hash_bytes = hashlib.sha256(string.encode("utf-8")).digest() @@ -200,7 +199,7 @@ def convert_string_to_id(string: Optional[str]) -> int: return id -def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: +def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: try: uuid_obj = uuid.UUID(uuid_v4) except Exception as e: @@ -209,7 +208,7 @@ def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: return convert_string_to_id(combined_key) -def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]: +def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None: if start_time_a is None: return None timestamp_in_seconds = start_time_a.timestamp() diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py index 1caa822cd0..f3dcbc5b8f 100644 --- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Optional from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event, Status, StatusCode @@ -10,12 +9,12 @@ class SpanData(BaseModel): model_config = {"arbitrary_types_allowed": True} trace_id: int = Field(..., description="The unique identifier for the trace.") - parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.") + parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.") span_id: int = Field(..., description="The unique identifier for this span.") name: str = Field(..., description="The name of the span.") attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.") events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.") links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.") status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") - start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.") - end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.") + start_time: int | None = Field(..., description="The start time of the span in nanoseconds.") + end_time: int | None = Field(..., description="The end time of the span in nanoseconds.") diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index c5fbe4d78b..1497bc1863 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from urllib.parse import urlparse from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes @@ -92,14 +92,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra raise -def datetime_to_nanos(dt: Optional[datetime]) -> int: +def datetime_to_nanos(dt: datetime | None) -> int: """Convert datetime to nanoseconds since epoch. If None, use current time.""" if dt is None: dt = datetime.now() return int(dt.timestamp() * 1_000_000_000) -def string_to_trace_id128(string: Optional[str]) -> int: +def string_to_trace_id128(string: str | None) -> int: """ Convert any input string into a stable 128-bit integer trace ID. @@ -284,7 +284,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): return file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -308,7 +308,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 1870da3781..d6f8164590 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,20 +1,20 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): - message_id: Optional[str] = None - message_data: Optional[Any] = None - inputs: Optional[Union[str, dict[str, Any], list]] = None - outputs: Optional[Union[str, dict[str, Any], list]] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None + message_id: str | None = None + message_data: Any | None = None + inputs: Union[str, dict[str, Any], list] | None = None + outputs: Union[str, dict[str, Any], list] | None = None + start_time: datetime | None = None + end_time: datetime | None = None metadata: dict[str, Any] - trace_id: Optional[str] = None + trace_id: str | None = None @field_validator("inputs", "outputs") @classmethod @@ -36,8 +36,8 @@ class BaseTraceInfo(BaseModel): class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any = None - conversation_id: Optional[str] = None - workflow_app_log_id: Optional[str] = None + conversation_id: str | None = None + workflow_app_log_id: str | None = None workflow_id: str tenant_id: str workflow_run_id: str @@ -46,7 +46,7 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_inputs: Mapping[str, Any] workflow_run_outputs: Mapping[str, Any] workflow_run_version: str - error: Optional[str] = None + error: str | None = None total_tokens: int file_list: list[str] query: str @@ -58,9 +58,9 @@ class MessageTraceInfo(BaseTraceInfo): message_tokens: int answer_tokens: int total_tokens: int - error: Optional[str] = None - file_list: Optional[Union[str, dict[str, Any], list]] = None - message_file_data: Optional[Any] = None + error: str | None = None + file_list: Union[str, dict[str, Any], list] | None = None + message_file_data: Any | None = None conversation_mode: str @@ -73,17 +73,17 @@ class ModerationTraceInfo(BaseTraceInfo): class SuggestedQuestionTraceInfo(BaseTraceInfo): total_tokens: int - status: Optional[str] = None - error: Optional[str] = None - from_account_id: Optional[str] = None - agent_based: Optional[bool] = None - from_source: Optional[str] = None - model_provider: Optional[str] = None - model_id: Optional[str] = None + status: str | None = None + error: str | None = None + from_account_id: str | None = None + agent_based: bool | None = None + from_source: str | None = None + model_provider: str | None = None + model_id: str | None = None suggested_question: list[str] level: str - status_message: Optional[str] = None - workflow_run_id: Optional[str] = None + status_message: str | None = None + workflow_run_id: str | None = None model_config = ConfigDict(protected_namespaces=()) @@ -98,7 +98,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_outputs: str metadata: dict[str, Any] message_file_data: Any = None - error: Optional[str] = None + error: str | None = None tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] @@ -106,7 +106,7 @@ class ToolTraceInfo(BaseTraceInfo): class GenerateNameTraceInfo(BaseTraceInfo): - conversation_id: Optional[str] = None + conversation_id: str | None = None tenant_id: str diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index 46ba1c45b9..312c7d3676 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -52,50 +52,50 @@ class LangfuseTrace(BaseModel): Langfuse trace model """ - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems " "or when creating a distributed trace. Traces are upserted on id.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the trace. Useful for sorting/filtering in the UI.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Union[str, dict[str, Any], list, None] | None = Field( default=None, description="The input of the trace. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Union[str, dict[str, Any], list, None] | None = Field( default=None, description="The output of the trace. Can be any JSON object." ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated " "via the API.", ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, description="The id of the user that triggered the execution. Used to provide user-level analytics.", ) - session_id: Optional[str] = Field( + session_id: str | None = Field( default=None, description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the trace type. Used to understand how changes to the trace type affect metrics. " "Useful in debugging.", ) - release: Optional[str] = Field( + release: str | None = Field( default=None, description="The release identifier of the current deployment. Used to understand how changes of different " "deployments affect metrics. Useful in debugging.", ) - tags: Optional[list[str]] = Field( + tags: list[str] | None = Field( default=None, description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET " "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.", ) - public: Optional[bool] = Field( + public: bool | None = Field( default=None, description="You can make a trace public to share it via a public link. This allows others to view the trace " "without needing to log in or be members of your Langfuse project.", @@ -113,61 +113,61 @@ class LangfuseSpan(BaseModel): Langfuse span model """ - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.", ) - session_id: Optional[str] = Field( + session_id: str | None = Field( default=None, description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.", ) - trace_id: Optional[str] = Field( + trace_id: str | None = Field( default=None, description="The id of the trace the span belongs to. Used to link spans to traces.", ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, description="The id of the user that triggered the execution. Used to provide user-level analytics.", ) - start_time: Optional[datetime | str] = Field( + start_time: datetime | str | None = Field( default_factory=datetime.now, description="The time at which the span started, defaults to the current time.", ) - end_time: Optional[datetime | str] = Field( + end_time: datetime | str | None = Field( default=None, description="The time at which the span ended. Automatically set by span.end().", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the span. Useful for sorting/filtering in the UI.", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated " "via the API.", ) - level: Optional[str] = Field( + level: str | None = Field( default=None, description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of " "traces with elevated error levels and for highlighting in the UI.", ) - status_message: Optional[str] = Field( + status_message: str | None = Field( default=None, description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( + input: Union[str, Mapping[str, Any], list, None] | None = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( + output: Union[str, Mapping[str, Any], list, None] | None = Field( default=None, description="The output of the span. Can be any JSON object." ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the span type. Used to understand how changes to the span type affect metrics. " "Useful in debugging.", ) - parent_observation_id: Optional[str] = Field( + parent_observation_id: str | None = Field( default=None, description="The id of the observation the span belongs to. Used to link spans to observations.", ) @@ -188,15 +188,15 @@ class UnitEnum(StrEnum): class GenerationUsage(BaseModel): - promptTokens: Optional[int] = None - completionTokens: Optional[int] = None - total: Optional[int] = None - input: Optional[int] = None - output: Optional[int] = None - unit: Optional[UnitEnum] = None - inputCost: Optional[float] = None - outputCost: Optional[float] = None - totalCost: Optional[float] = None + promptTokens: int | None = None + completionTokens: int | None = None + total: int | None = None + input: int | None = None + output: int | None = None + unit: UnitEnum | None = None + inputCost: float | None = None + outputCost: float | None = None + totalCost: float | None = None @field_validator("input", "output") @classmethod @@ -206,69 +206,69 @@ class GenerationUsage(BaseModel): class LangfuseGeneration(BaseModel): - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the generation can be set, defaults to random id.", ) - trace_id: Optional[str] = Field( + trace_id: str | None = Field( default=None, description="The id of the trace the generation belongs to. Used to link generations to traces.", ) - parent_observation_id: Optional[str] = Field( + parent_observation_id: str | None = Field( default=None, description="The id of the observation the generation belongs to. Used to link generations to observations.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the generation. Useful for sorting/filtering in the UI.", ) - start_time: Optional[datetime | str] = Field( + start_time: datetime | str | None = Field( default_factory=datetime.now, description="The time at which the generation started, defaults to the current time.", ) - completion_start_time: Optional[datetime | str] = Field( + completion_start_time: datetime | str | None = Field( default=None, description="The time at which the completion started (streaming). Set it to get latency analytics broken " "down into time until completion started and completion duration.", ) - end_time: Optional[datetime | str] = Field( + end_time: datetime | str | None = Field( default=None, description="The time at which the generation ended. Automatically set by generation.end().", ) - model: Optional[str] = Field(default=None, description="The name of the model used for the generation.") - model_parameters: Optional[dict[str, Any]] = Field( + model: str | None = Field(default=None, description="The name of the model used for the generation.") + model_parameters: dict[str, Any] | None = Field( default=None, description="The parameters of the model used for the generation; can be any key-value pairs.", ) - input: Optional[Any] = Field( + input: Any | None = Field( default=None, description="The prompt used for the generation. Can be any string or JSON object.", ) - output: Optional[Any] = Field( + output: Any | None = Field( default=None, description="The completion generated by the model. Can be any string or JSON object.", ) - usage: Optional[GenerationUsage] = Field( + usage: GenerationUsage | None = Field( default=None, description="The usage object supports the OpenAi structure with tokens and a more generic version with " "detailed costs and units.", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being " "updated via the API.", ) - level: Optional[LevelEnum] = Field( + level: LevelEnum | None = Field( default=None, description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering " "of traces with elevated error levels and for highlighting in the UI.", ) - status_message: Optional[str] = Field( + status_message: str | None = Field( default=None, description="The status message of the generation. Additional field for context of the event. E.g. the error " "message of an error event.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the generation type. Used to understand how changes to the span type affect " "metrics. Useful in debugging.", diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 3a03d9f4fe..f05e59c28e 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,7 +1,6 @@ import logging import os from datetime import datetime, timedelta -from typing import Optional from langfuse import Langfuse # type: ignore from sqlalchemy.orm import sessionmaker @@ -242,7 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -399,7 +398,7 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_span(langfuse_span_data=name_generation_span_data) - def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): + def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None): format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} try: self.langfuse_client.trace(**format_trace_data) @@ -407,7 +406,7 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create trace: {str(e)}") - def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None): + def add_span(self, langfuse_span_data: LangfuseSpan | None = None): format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} try: self.langfuse_client.span(**format_span_data) @@ -415,12 +414,12 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create span: {str(e)}") - def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None): + def update_span(self, span, langfuse_span_data: LangfuseSpan | None = None): format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} span.end(**format_span_data) - def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None): + def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None): format_generation_data = ( filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) @@ -430,7 +429,7 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create generation: {str(e)}") - def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None): + def update_generation(self, generation, langfuse_generation_data: LangfuseGeneration | None = None): format_generation_data = ( filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 4fd01136ba..f73ba01c8b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -20,36 +20,36 @@ class LangSmithRunType(StrEnum): class LangSmithTokenUsage(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - total_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None class LangSmithMultiModel(BaseModel): - file_list: Optional[list[str]] = Field(None, description="List of files") + file_list: list[str] | None = Field(None, description="List of files") class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): - name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") + name: str | None = Field(..., description="Name of the run") + inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the run") + outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") - start_time: Optional[datetime | str] = Field(None, description="Start time of the run") - end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") - error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") - parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") - tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") - id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field(None, description="Session ID associated with the run") - session_name: Optional[str] = Field(None, description="Session name associated with the run") - reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") - input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") - output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") + start_time: datetime | str | None = Field(None, description="Start time of the run") + end_time: datetime | str | None = Field(None, description="End time of the run") + extra: dict[str, Any] | None = Field(None, description="Extra information of the run") + error: str | None = Field(None, description="Error message of the run") + serialized: dict[str, Any] | None = Field(None, description="Serialized data of the run") + parent_run_id: str | None = Field(None, description="Parent run ID") + events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run") + tags: list[str] | None = Field(None, description="Tags associated with the run") + trace_id: str | None = Field(None, description="Trace ID associated with the run") + dotted_order: str | None = Field(None, description="Dotted order of the run") + id: str | None = Field(None, description="ID of the run") + session_id: str | None = Field(None, description="Session ID associated with the run") + session_name: str | None = Field(None, description="Session name associated with the run") + reference_example_id: str | None = Field(None, description="Reference example ID associated with the run") + input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run") + output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") @classmethod @@ -128,15 +128,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") - parent_run_id: Optional[str] = Field(None, description="Parent run ID") - end_time: Optional[datetime | str] = Field(None, description="End time of the run") - error: Optional[str] = Field(None, description="Error message of the run") - inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") - outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") - tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") - input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") - output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") + trace_id: str | None = Field(None, description="Trace ID associated with the run") + dotted_order: str | None = Field(None, description="Dotted order of the run") + parent_run_id: str | None = Field(None, description="Parent run ID") + end_time: datetime | str | None = Field(None, description="End time of the run") + error: str | None = Field(None, description="Error message of the run") + inputs: dict[str, Any] | None = Field(None, description="Inputs of the run") + outputs: dict[str, Any] | None = Field(None, description="Outputs of the run") + events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run") + tags: list[str] | None = Field(None, description="Tags associated with the run") + extra: dict[str, Any] | None = Field(None, description="Extra information of the run") + input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run") + output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f9e5128e89..ff8a346d27 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import cast from langsmith import Client from langsmith.schemas import RunBase @@ -247,7 +247,7 @@ class LangSmithDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata @@ -260,7 +260,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index dd6a424ddb..20e880e940 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 @@ -47,7 +47,7 @@ def wrap_metadata(metadata, **kwargs): return metadata -def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]): +def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None): """Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most messages and objects. The type-hints of BaseTraceInfo indicates that objects start_time and message_id could be null which means we cannot map @@ -264,7 +264,7 @@ class OpikDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -282,7 +282,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index a2f1969bc8..aeed928581 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -5,7 +5,7 @@ import queue import threading import time from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Union from uuid import UUID, uuid4 from cachetools import LRUCache @@ -218,7 +218,7 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -242,7 +242,7 @@ class OpsTraceManager: @classmethod def get_ops_trace_instance( cls, - app_id: Optional[Union[UUID, str]] = None, + app_id: Union[UUID, str] | None = None, ): """ Get ops trace through model config @@ -255,7 +255,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: return None @@ -329,7 +329,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.query(App).where(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -347,7 +347,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -405,11 +405,11 @@ class TraceTask: def __init__( self, trace_type: Any, - message_id: Optional[str] = None, - workflow_execution: Optional[WorkflowExecution] = None, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - timer: Optional[Any] = None, + message_id: str | None = None, + workflow_execution: WorkflowExecution | None = None, + conversation_id: str | None = None, + user_id: str | None = None, + timer: Any | None = None, **kwargs, ): self.trace_type = trace_type @@ -823,7 +823,7 @@ class TraceTask: return generate_name_trace_info -trace_manager_timer: Optional[threading.Timer] = None +trace_manager_timer: threading.Timer | None = None trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 2c0afb1600..5e8651d6f9 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from datetime import datetime -from typing import Optional, Union +from typing import Union from urllib.parse import urlparse from sqlalchemy import select @@ -49,9 +49,7 @@ def replace_text_with_content(data): return data -def generate_dotted_order( - run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None -) -> str: +def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str: """ generate dotted_order for langsmith """ diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index 7f489f37ac..ef1a3be45b 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -8,24 +8,24 @@ from core.ops.utils import replace_text_with_content class WeaveTokenUsage(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - total_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None class WeaveMultiModel(BaseModel): - file_list: Optional[list[str]] = Field(None, description="List of files") + file_list: list[str] | None = Field(None, description="List of files") class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") - attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( + inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace") + outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace") + attributes: Union[str, dict[str, Any], list, None] | None = Field( None, description="Metadata and attributes associated with trace" ) - exception: Optional[str] = Field(None, description="Exception message of the trace") + exception: str | None = Field(None, description="Exception message of the trace") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 8eb94cc679..c8532c63be 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Any, Optional, cast +from typing import Any, cast import wandb import weave @@ -223,7 +223,7 @@ class WeaveDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) attributes = trace_info.metadata @@ -236,7 +236,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -424,7 +424,7 @@ class WeaveDataTrace(BaseTraceInstance): logger.debug("Weave API check failed: %s", str(e)) raise ValueError(f"Weave API check failed: {str(e)}") - def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): + def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None): call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self.calls[run_data.id] = call if parent_run_id: diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 86cc3839f2..9352a55be0 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Optional, Union +from typing import Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -53,8 +53,8 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app_id: str, user_id: str, tenant_id: str, - conversation_id: Optional[str], - query: Optional[str], + conversation_id: str | None, + query: str | None, stream: bool, inputs: Mapping, files: list[dict], diff --git a/api/core/plugin/backwards_invocation/base.py b/api/core/plugin/backwards_invocation/base.py index 2a5f857576..a89b0f95be 100644 --- a/api/core/plugin/backwards_invocation/base.py +++ b/api/core/plugin/backwards_invocation/base.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from pydantic import BaseModel @@ -23,5 +23,5 @@ T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel) class BaseBackwardsInvocationResponse(BaseModel, Generic[T]): - data: Optional[T] = None + data: T | None = None error: str = "" diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 06773504d9..c2d1574e67 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,7 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool diff --git a/api/core/plugin/entities/endpoint.py b/api/core/plugin/entities/endpoint.py index d7ba75bb4f..e5bca140f8 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import BaseModel, Field, model_validator @@ -24,7 +23,7 @@ class EndpointProviderDeclaration(BaseModel): """ settings: list[ProviderConfig] = Field(default_factory=list) - endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration]) + endpoints: list[EndpointDeclaration] | None = Field(default_factory=list[EndpointDeclaration]) class EndpointEntity(BasePluginEntity): diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 1c13a621d4..e0762619e6 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field, model_validator from core.model_runtime.entities.provider_entities import ProviderEntity @@ -19,11 +17,11 @@ class MarketplacePluginDeclaration(BaseModel): resource: PluginResourceRequirements = Field( ..., description="Specification of computational resources needed to run the plugin" ) - endpoint: Optional[EndpointProviderDeclaration] = Field( + endpoint: EndpointProviderDeclaration | None = Field( None, description="Configuration for the plugin's API endpoint, if applicable" ) - model: Optional[ProviderEntity] = Field(None, description="Details of the AI model used by the plugin, if any") - tool: Optional[ToolProviderEntity] = Field( + model: ProviderEntity | None = Field(None, description="Details of the AI model used by the plugin, if any") + tool: ToolProviderEntity | None = Field( None, description="Information about the tool functionality provided by the plugin, if any" ) latest_version: str = Field( diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 01bb011ce7..0f7604b368 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,6 +1,6 @@ import json from enum import StrEnum, auto -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator @@ -12,9 +12,7 @@ from core.workflow.nodes.base.entities import NumberType class PluginParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - icon: Optional[str] = Field( - default=None, description="The icon of the option, can be a url or a base64 encoded image" - ) + icon: str | None = Field(default=None, description="The icon of the option, can be a url or a base64 encoded image") @field_validator("value", mode="before") @classmethod @@ -74,15 +72,15 @@ class PluginParameterTemplate(BaseModel): class PluginParameter(BaseModel): name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") - placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") + placeholder: I18nObject | None = Field(default=None, description="The placeholder presented to the user") scope: str | None = None - auto_generate: Optional[PluginParameterAutoGenerate] = None - template: Optional[PluginParameterTemplate] = None + auto_generate: PluginParameterAutoGenerate | None = None + template: PluginParameterTemplate | None = None required: bool = False - default: Optional[Union[float, int, str]] = None - min: Optional[Union[float, int]] = None - max: Optional[Union[float, int]] = None - precision: Optional[int] = None + default: Union[float, int, str] | None = None + min: Union[float, int] | None = None + max: Union[float, int] | None = None + precision: int | None = None options: list[PluginParameterOption] = Field(default_factory=list) @field_validator("options", mode="before") diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 261d97c2b6..adc80d1e94 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -2,7 +2,7 @@ import datetime import re from collections.abc import Mapping from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -28,34 +28,34 @@ class PluginResourceRequirements(BaseModel): class Permission(BaseModel): class Tool(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Model(BaseModel): - enabled: Optional[bool] = Field(default=False) - llm: Optional[bool] = Field(default=False) - text_embedding: Optional[bool] = Field(default=False) - rerank: Optional[bool] = Field(default=False) - tts: Optional[bool] = Field(default=False) - speech2text: Optional[bool] = Field(default=False) - moderation: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) + llm: bool | None = Field(default=False) + text_embedding: bool | None = Field(default=False) + rerank: bool | None = Field(default=False) + tts: bool | None = Field(default=False) + speech2text: bool | None = Field(default=False) + moderation: bool | None = Field(default=False) class Node(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Endpoint(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Storage(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) size: int = Field(ge=1024, le=1073741824, default=1048576) - tool: Optional[Tool] = Field(default=None) - model: Optional[Model] = Field(default=None) - node: Optional[Node] = Field(default=None) - endpoint: Optional[Endpoint] = Field(default=None) - storage: Optional[Storage] = Field(default=None) + tool: Tool | None = Field(default=None) + model: Model | None = Field(default=None) + node: Node | None = Field(default=None) + endpoint: Endpoint | None = Field(default=None) + storage: Storage | None = Field(default=None) - permission: Optional[Permission] = Field(default=None) + permission: Permission | None = Field(default=None) class PluginCategory(StrEnum): @@ -67,17 +67,17 @@ class PluginCategory(StrEnum): class PluginDeclaration(BaseModel): class Plugins(BaseModel): - tools: Optional[list[str]] = Field(default_factory=list[str]) - models: Optional[list[str]] = Field(default_factory=list[str]) - endpoints: Optional[list[str]] = Field(default_factory=list[str]) + tools: list[str] | None = Field(default_factory=list[str]) + models: list[str] | None = Field(default_factory=list[str]) + endpoints: list[str] | None = Field(default_factory=list[str]) class Meta(BaseModel): - minimum_dify_version: Optional[str] = Field(default=None) - version: Optional[str] = Field(default=None) + minimum_dify_version: str | None = Field(default=None) + version: str | None = Field(default=None) @field_validator("minimum_dify_version") @classmethod - def validate_minimum_dify_version(cls, v: Optional[str]) -> Optional[str]: + def validate_minimum_dify_version(cls, v: str | None) -> str | None: if v is None: return v try: @@ -87,23 +87,23 @@ class PluginDeclaration(BaseModel): raise ValueError(f"Invalid version format: {v}") from e version: str = Field(...) - author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") + author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") description: I18nObject icon: str - icon_dark: Optional[str] = Field(default=None) + icon_dark: str | None = Field(default=None) label: I18nObject category: PluginCategory created_at: datetime.datetime resource: PluginResourceRequirements plugins: Plugins tags: list[str] = Field(default_factory=list) - repo: Optional[str] = Field(default=None) + repo: str | None = Field(default=None) verified: bool = Field(default=False) - tool: Optional[ToolProviderEntity] = None - model: Optional[ProviderEntity] = None - endpoint: Optional[EndpointProviderDeclaration] = None - agent_strategy: Optional[AgentStrategyProviderEntity] = None + tool: ToolProviderEntity | None = None + model: ProviderEntity | None = None + endpoint: EndpointProviderDeclaration | None = None + agent_strategy: AgentStrategyProviderEntity | None = None meta: Meta @field_validator("version") @@ -233,9 +233,9 @@ class PluginDependency(BaseModel): type: Type value: Github | Marketplace | Package - current_identifier: Optional[str] = None + current_identifier: str | None = None class MissingPluginDependency(BaseModel): plugin_unique_identifier: str - current_identifier: Optional[str] = None + current_identifier: str | None = None diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index f1d6860bb4..d6f0dd8121 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar from pydantic import BaseModel, ConfigDict, Field @@ -24,7 +24,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]): code: int message: str - data: Optional[T] = None + data: T | None = None class InstallPluginMessage(BaseModel): @@ -174,7 +174,7 @@ class PluginVerification(BaseModel): class PluginDecodeResponse(BaseModel): unique_identifier: str = Field(description="The unique identifier of the plugin.") manifest: PluginDeclaration - verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information") + verification: PluginVerification | None = Field(default=None, description="Basic verification information") class PluginOAuthAuthorizationUrlResponse(BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 3a783dad3e..10f37f75f8 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -35,7 +35,7 @@ class InvokeCredentials(BaseModel): class PluginInvokeContext(BaseModel): - credentials: Optional[InvokeCredentials] = Field( + credentials: InvokeCredentials | None = Field( default_factory=InvokeCredentials, description="Credentials context for the plugin invocation or backward invocation.", ) @@ -50,7 +50,7 @@ class RequestInvokeTool(BaseModel): provider: str tool: str tool_parameters: dict - credential_id: Optional[str] = None + credential_id: str | None = None class BaseRequestInvokeModel(BaseModel): @@ -70,9 +70,9 @@ class RequestInvokeLLM(BaseRequestInvokeModel): mode: str completion_params: dict[str, Any] = Field(default_factory=dict) prompt_messages: list[PromptMessage] = Field(default_factory=list) - tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool]) - stop: Optional[list[str]] = Field(default_factory=list[str]) - stream: Optional[bool] = False + tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool]) + stop: list[str] | None = Field(default_factory=list[str]) + stream: bool | None = False model_config = ConfigDict(protected_namespaces=()) @@ -194,10 +194,10 @@ class RequestInvokeApp(BaseModel): app_id: str inputs: dict[str, Any] - query: Optional[str] = None + query: str | None = None response_mode: Literal["blocking", "streaming"] - conversation_id: Optional[str] = None - user: Optional[str] = None + conversation_id: str | None = None + user: str | None = None files: list[dict] = Field(default_factory=list) diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 526f6f2961..0b55f20522 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.plugin.entities.plugin import GenericProviderID @@ -82,10 +82,10 @@ class PluginAgentClient(BasePluginClient): agent_provider: str, agent_strategy: str, agent_params: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - context: Optional[PluginInvokeContext] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + context: PluginInvokeContext | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 85a72d9f82..153da142f4 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO, Optional +from typing import IO from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -151,9 +151,9 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, ) -> Generator[LLMResultChunk, None, None]: """ @@ -200,7 +200,7 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> int: """ Get number of tokens for llm @@ -325,8 +325,8 @@ class PluginModelClient(BasePluginClient): credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, + score_threshold: float | None = None, + top_n: int | None = None, ) -> RerankResult: """ Invoke rerank @@ -414,7 +414,7 @@ class PluginModelClient(BasePluginClient): provider: str, model: str, credentials: dict, - language: Optional[str] = None, + language: str | None = None, ): """ Get tts model voices diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 7199c0d15a..bb68f4700c 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -81,9 +81,9 @@ class PluginToolManager(BasePluginClient): credentials: dict[str, Any], credential_type: CredentialType, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters. @@ -153,9 +153,9 @@ class PluginToolManager(BasePluginClient): provider: str, credentials: dict[str, Any], tool: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters of the tool diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 11c6e5c23b..5f2ffefd94 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Optional, cast +from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -41,11 +41,11 @@ class AdvancedPromptTransform(PromptTransform): inputs: Mapping[str, str], query: str, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: prompt_messages = [] @@ -80,13 +80,13 @@ class AdvancedPromptTransform(PromptTransform): self, prompt_template: CompletionModelPromptTemplate, inputs: Mapping[str, str], - query: Optional[str], + query: str | None, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Get completion model prompt messages. @@ -141,13 +141,13 @@ class AdvancedPromptTransform(PromptTransform): self, prompt_template: list[ChatModelMessage], inputs: Mapping[str, str], - query: Optional[str], + query: str | None, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Get chat model prompt messages. diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 09f017a7db..a96b094e6d 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import cast from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, @@ -23,7 +23,7 @@ class AgentHistoryPromptTransform(PromptTransform): model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage], history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, + memory: TokenBufferMemory | None = None, ): self.model_config = model_config self.prompt_messages = prompt_messages diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index c8e7b414df..7094633093 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -12,7 +12,7 @@ class ChatModelMessage(BaseModel): text: str role: PromptMessageRole - edition_type: Optional[Literal["basic", "jinja2"]] = None + edition_type: Literal["basic", "jinja2"] | None = None class CompletionModelPromptTemplate(BaseModel): @@ -21,7 +21,7 @@ class CompletionModelPromptTemplate(BaseModel): """ text: str - edition_type: Optional[Literal["basic", "jinja2"]] = None + edition_type: Literal["basic", "jinja2"] | None = None class MemoryConfig(BaseModel): @@ -43,8 +43,8 @@ class MemoryConfig(BaseModel): """ enabled: bool - size: Optional[int] = None + size: int | None = None - role_prefix: Optional[RolePrefix] = None + role_prefix: RolePrefix | None = None window: WindowConfig - query_prompt_template: Optional[str] = None + query_prompt_template: str | None = None diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 1f040599be..a6e873d587 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -55,8 +55,8 @@ class PromptTransform: memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None, + human_prefix: str | None = None, + ai_prefix: str | None = None, ) -> str: """Get memory messages.""" kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fdab0af103..d1d518a55d 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -2,7 +2,7 @@ import json import os from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -45,11 +45,11 @@ class SimplePromptTransform(PromptTransform): inputs: Mapping[str, str], query: str, files: Sequence["File"], - context: Optional[str], - memory: Optional[TokenBufferMemory], + context: str | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode(model_config.mode) @@ -86,9 +86,9 @@ class SimplePromptTransform(PromptTransform): model_config: ModelConfigWithCredentialsEntity, pre_prompt: str, inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, + query: str | None = None, + context: str | None = None, + histories: str | None = None, ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( @@ -182,12 +182,12 @@ class SimplePromptTransform(PromptTransform): pre_prompt: str, inputs: dict, query: str, - context: Optional[str], + context: str | None, files: Sequence["File"], - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] # get prompt @@ -228,12 +228,12 @@ class SimplePromptTransform(PromptTransform): pre_prompt: str, inputs: dict, query: str, - context: Optional[str], + context: str | None, files: Sequence["File"], - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( app_mode=app_mode, @@ -281,7 +281,7 @@ class SimplePromptTransform(PromptTransform): self, prompt: str, files: Sequence["File"], - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index e4e8b09a04..082c6c4c50 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,7 @@ import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -281,7 +281,7 @@ class ProviderManager: model_type_instance=model_type_instance, ) - def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: + def get_default_model(self, tenant_id: str, model_type: ModelType) -> DefaultModelEntity | None: """ Get default model. @@ -1036,8 +1036,8 @@ class ProviderManager: def _to_model_settings( self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + provider_model_settings: list[ProviderModelSetting] | None = None, + load_balancing_model_configs: list[LoadBalancingModelConfig] | None = None, ) -> list[ModelSettings]: """ Convert to model settings. diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index d17d76333e..696e3e967f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError @@ -18,8 +16,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + reranking_model: dict | None = None, + weights: dict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -29,9 +27,9 @@ class DataPostProcessor: self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -45,9 +43,9 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - ) -> Optional[BaseRerankRunner]: + reranking_model: dict | None = None, + weights: dict | None = None, + ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: runner = RerankRunnerFactory.create_rerank_runner( runner_type=reranking_mode, @@ -74,12 +72,12 @@ class DataPostProcessor: return runner return None - def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: + def _get_reorder_runner(self, reorder_enabled) -> ReorderRunner | None: if reorder_enabled: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 096f40f707..70690a4c56 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Optional +from typing import Any import orjson from pydantic import BaseModel @@ -143,7 +143,7 @@ class Jieba(BaseKeyword): storage.delete(file_key) storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) - def _get_dataset_keyword_table(self) -> Optional[dict]: + def _get_dataset_keyword_table(self) -> dict | None: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index a6214d955b..81619570f9 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,5 +1,5 @@ import re -from typing import Optional, cast +from typing import cast class JiebaKeywordTableHandler: @@ -10,7 +10,7 @@ class JiebaKeywordTableHandler: jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore - def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: + def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" import jieba.analyse # type: ignore diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index fefd42f84d..429744c0de 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,6 +1,5 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor -from typing import Optional from flask import Flask, current_app from sqlalchemy import select @@ -39,11 +38,11 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float] = 0.0, - reranking_model: Optional[dict] = None, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, reranking_mode: str = "reranking_model", - weights: Optional[dict] = None, - document_ids_filter: Optional[list[str]] = None, + weights: dict | None = None, + document_ids_filter: list[str] | None = None, ): if not query: return [] @@ -125,8 +124,8 @@ class RetrievalService: cls, dataset_id: str, query: str, - external_retrieval_model: Optional[dict] = None, - metadata_filtering_conditions: Optional[dict] = None, + external_retrieval_model: dict | None = None, + metadata_filtering_conditions: dict | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) @@ -145,7 +144,7 @@ class RetrievalService: return all_documents @classmethod - def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: + def _get_dataset(cls, dataset_id: str) -> Dataset | None: with Session(db.engine) as session: return session.query(Dataset).where(Dataset.id == dataset_id).first() @@ -158,7 +157,7 @@ class RetrievalService: top_k: int, all_documents: list, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: @@ -182,12 +181,12 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], + score_threshold: float | None, + reranking_model: dict | None, all_documents: list, retrieval_method: str, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: @@ -235,12 +234,12 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], + score_threshold: float | None, + reranking_model: dict | None, all_documents: list, retrieval_method: str, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index c3a6127e4a..77a0fa6cf2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: Optional[str] = None + namespace_password: str | None = None metrics: str = "cosine" read_timeout: int = 60000 diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index e7128b183e..de1572410c 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any import chromadb from chromadb import QueryResult, Settings @@ -20,8 +20,8 @@ class ChromaConfig(BaseModel): port: int tenant: str database: str - auth_provider: Optional[str] = None - auth_credentials: Optional[str] = None + auth_provider: str | None = None + auth_credentials: str | None = None def to_chroma_params(self): settings = Settings( diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index eb4cbd2324..7a4f11a453 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -84,7 +84,7 @@ class ClickzettaConnectionPool: self._pool_locks: dict[str, threading.Lock] = {} self._max_pool_size = 5 # Maximum connections per configuration self._connection_timeout = 300 # 5 minutes timeout - self._cleanup_thread: Optional[threading.Thread] = None + self._cleanup_thread: threading.Thread | None = None self._shutdown = False self._start_cleanup_thread() @@ -303,8 +303,8 @@ class ClickzettaVector(BaseVector): """ # Class-level write queue and lock for serializing writes - _write_queue: Optional[queue.Queue] = None - _write_thread: Optional[threading.Thread] = None + _write_queue: queue.Queue | None = None + _write_thread: threading.Thread | None = None _write_lock = threading.Lock() _shutdown = False @@ -328,7 +328,7 @@ class ClickzettaVector(BaseVector): def __init__(self, vector_instance: "ClickzettaVector"): self.vector = vector_instance - self.connection: Optional[Connection] = None + self.connection: Connection | None = None def __enter__(self) -> "Connection": self.connection = self.vector._get_connection() diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py index 7118029d40..7b00928b7b 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any from flask import current_app @@ -22,8 +22,8 @@ class ElasticSearchJaVector(ElasticSearchVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index df1c747585..2c147fa7ca 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import urlparse import requests @@ -24,18 +24,18 @@ logger = logging.getLogger(__name__) class ElasticSearchConfig(BaseModel): # Regular Elasticsearch config - host: Optional[str] = None - port: Optional[int] = None - username: Optional[str] = None - password: Optional[str] = None + host: str | None = None + port: int | None = None + username: str | None = None + password: str | None = None # Elastic Cloud specific config - cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud - api_key: Optional[str] = None + cloud_url: str | None = None # Cloud URL for Elasticsearch Cloud + api_key: str | None = None # Common config use_cloud: bool = False - ca_certs: Optional[str] = None + ca_certs: str | None = None verify_certs: bool = False request_timeout: int = 100000 retry_on_timeout: bool = True @@ -256,8 +256,8 @@ class ElasticSearchVector(BaseVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 0eca37a129..cfee090768 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -1,7 +1,7 @@ import json import logging import ssl -from typing import Any, Optional +from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator @@ -157,8 +157,8 @@ class HuaweiCloudVector(BaseVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 5097412c2c..f3ec30d178 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -2,7 +2,7 @@ import copy import json import logging import time -from typing import Any, Optional +from typing import Any from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError @@ -29,10 +29,10 @@ UGC_INDEX_PREFIX = "ugc_index" class LindormVectorStoreConfig(BaseModel): hosts: str - username: Optional[str] = None - password: Optional[str] = None - using_ugc: Optional[bool] = False - request_timeout: Optional[float] = 1.0 # timeout units: s + username: str | None = None + password: str | None = None + using_ugc: bool | None = False + request_timeout: float | None = 1.0 # timeout units: s @model_validator(mode="before") @classmethod @@ -448,13 +448,13 @@ def default_text_search_query( query_text: str, k: int = 4, text_field: str = Field.CONTENT_KEY.value, - must: Optional[list[dict]] = None, - must_not: Optional[list[dict]] = None, - should: Optional[list[dict]] = None, + must: list[dict] | None = None, + must_not: list[dict] | None = None, + should: list[dict] | None = None, minimum_should_match: int = 0, - filters: Optional[list[dict]] = None, - routing: Optional[str] = None, - routing_field: Optional[str] = None, + filters: list[dict] | None = None, + routing: str | None = None, + routing_field: str | None = None, **kwargs, ): query_clause: dict[str, Any] = {} @@ -505,13 +505,13 @@ def default_vector_search_query( query_vector: list[float], k: int = 4, min_score: str = "0.0", - ef_search: Optional[str] = None, # only for hnsw - nprobe: Optional[str] = None, # "2000" - reorder_factor: Optional[str] = None, # "20" - client_refactor: Optional[str] = None, # "true" + ef_search: str | None = None, # only for hnsw + nprobe: str | None = None, # "2000" + reorder_factor: str | None = None, # "20" + client_refactor: str | None = None, # "true" vector_field: str = Field.VECTOR.value, - filters: Optional[list[dict]] = None, - filter_type: Optional[str] = None, + filters: list[dict] | None = None, + filter_type: str | None = None, **kwargs, ): if filters is not None: diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 3dd073ce50..f6e8fa6d0f 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -3,7 +3,7 @@ import logging import uuid from collections.abc import Callable from functools import wraps -from typing import Any, Concatenate, Optional, ParamSpec, TypeVar +from typing import Any, Concatenate, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -74,7 +74,7 @@ class MatrixoneVector(BaseVector): self.client = self._get_client(len(embeddings[0]), True) return self.add_texts(texts, embeddings) - def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient: + def _get_client(self, dimension: int | None = None, create_table: bool = False) -> MoVectorClient: """ Create a new client for the collection. diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 2ec48ae365..5f32feb709 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any from packaging import version from pydantic import BaseModel, model_validator @@ -26,13 +26,13 @@ class MilvusConfig(BaseModel): """ uri: str # Milvus server URI - token: Optional[str] = None # Optional token for authentication - user: Optional[str] = None # Username for authentication - password: Optional[str] = None # Password for authentication + token: str | None = None # Optional token for authentication + user: str | None = None # Username for authentication + password: str | None = None # Password for authentication batch_size: int = 100 # Batch size for operations database: str = "default" # Database name enable_hybrid_search: bool = False # Flag to enable hybrid search - analyzer_params: Optional[str] = None # Analyzer params + analyzer_params: str | None = None # Analyzer params @model_validator(mode="before") @classmethod @@ -79,7 +79,7 @@ class MilvusVector(BaseVector): self._load_collection_fields() self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported - def _load_collection_fields(self, fields: Optional[list[str]] = None): + def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server collection_info = self._client.describe_collection(self._collection_name) @@ -292,7 +292,7 @@ class MilvusVector(BaseVector): ) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): """ Create a new collection in Milvus with the specified schema and index parameters. diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 3f65a4a275..764248e6fa 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Literal, Optional +from typing import Any, Literal from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -26,10 +26,10 @@ class OpenSearchConfig(BaseModel): secure: bool = False # use_ssl verify_certs: bool = True auth_method: Literal["basic", "aws_managed_iam"] = "basic" - user: Optional[str] = None - password: Optional[str] = None - aws_region: Optional[str] = None - aws_service: Optional[str] = None + user: str | None = None + password: str | None = None + aws_region: str | None = None + aws_service: str | None = None @model_validator(mode="before") @classmethod @@ -236,7 +236,7 @@ class OpenSearchVector(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index d329220580..d46f29bd64 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import qdrant_client from flask import current_app @@ -46,7 +46,7 @@ class PathQdrantParams(BaseModel): class UrlQdrantParams(BaseModel): url: str - api_key: Optional[str] + api_key: str | None timeout: float verify: bool grpc_port: int @@ -55,9 +55,9 @@ class UrlQdrantParams(BaseModel): class QdrantConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 20 - root_path: Optional[str] = None + root_path: str | None = None grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 @@ -189,10 +189,10 @@ class QdrantVector(BaseVector): self, texts: Iterable[str], embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, - group_id: Optional[str] = None, + group_id: str | None = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -234,7 +234,7 @@ class QdrantVector(BaseVector): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, group_id: str, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 9d3dc7c622..99698fcdd0 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -1,6 +1,6 @@ import json import uuid -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert @@ -160,7 +160,7 @@ class RelytVector(BaseVector): else: return None - def delete_by_uuids(self, ids: Optional[list[str]] = None): + def delete_by_uuids(self, ids: list[str] | None = None): """Delete by vector IDs. Args: @@ -241,7 +241,7 @@ class RelytVector(BaseVector): self, embedding: list[float], k: int = 4, - filter: Optional[dict] = None, + filter: dict | None = None, ) -> list[tuple[Document, float]]: # Add the filter if provided diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 27685b7ddf..e91d9bb0d6 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -2,7 +2,7 @@ import json import logging import math from collections.abc import Iterable -from typing import Any, Optional +from typing import Any import tablestore # type: ignore from pydantic import BaseModel, model_validator @@ -22,11 +22,11 @@ logger = logging.getLogger(__name__) class TableStoreConfig(BaseModel): - access_key_id: Optional[str] = None - access_key_secret: Optional[str] = None - instance_name: Optional[str] = None - endpoint: Optional[str] = None - normalize_full_text_bm25_score: Optional[bool] = False + access_key_id: str | None = None + access_key_secret: str | None = None + instance_name: str | None = None + endpoint: str | None = None + normalize_full_text_bm25_score: bool | None = False @model_validator(mode="before") @classmethod diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 2485857070..291d047c04 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any from pydantic import BaseModel from tcvdb_text.encoder import BM25Encoder # type: ignore @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class TencentConfig(BaseModel): url: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 30 - username: Optional[str] = None - database: Optional[str] = None + username: str | None = None + database: str | None = None index_type: str = "HNSW" metric_type: str = "IP" shard: int = 1 diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 7055581459..f90a311df4 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import qdrant_client import requests @@ -45,9 +45,9 @@ if TYPE_CHECKING: class TidbOnQdrantConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 20 - root_path: Optional[str] = None + root_path: str | None = None grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 @@ -180,10 +180,10 @@ class TidbOnQdrantVector(BaseVector): self, texts: Iterable[str], embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, - group_id: Optional[str] = None, + group_id: str | None = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -225,7 +225,7 @@ class TidbOnQdrantVector(BaseVector): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, group_id: str, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b2cc51d034..dc4f026ff3 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,7 +1,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from sqlalchemy import select @@ -32,7 +32,7 @@ class AbstractVectorFactory(ABC): class Vector: - def __init__(self, dataset: Dataset, attributes: Optional[list] = None): + def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset @@ -180,7 +180,7 @@ class Vector: case _: raise ValueError(f"Vector store {vector_type} is not supported.") - def create(self, texts: Optional[list] = None, **kwargs): + def create(self, texts: list | None = None, **kwargs): if texts: start = time.time() logger.info("start embedding %s texts %s", len(texts), start) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 43dde37c7e..3ec08b93ed 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Any, Optional +from typing import Any import requests import weaviate # type: ignore @@ -19,7 +19,7 @@ from models.dataset import Dataset class WeaviateConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None batch_size: int = 100 @model_validator(mode="before") diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 63c6db8d06..74a2653e9d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from sqlalchemy import func, select @@ -15,7 +15,7 @@ class DatasetDocumentStore: self, dataset: Dataset, user_id: str, - document_id: Optional[str] = None, + document_id: str | None = None, ): self._dataset = dataset self._user_id = user_id @@ -176,7 +176,7 @@ class DatasetDocumentStore: result = self.get_document_segment(doc_id) return result is not None - def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Document | None: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -217,16 +217,16 @@ class DatasetDocumentStore: document_segment.index_node_hash = doc_hash db.session.commit() - def get_document_hash(self, doc_id: str) -> Optional[str]: + def get_document_hash(self, doc_id: str) -> str | None: """Get the stored hash for a document, if it exists.""" document_segment = self.get_document_segment(doc_id) if document_segment is None: return None - data: Optional[str] = document_segment.index_node_hash + data: str | None = document_segment.index_node_hash return data - def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: + def get_document_segment(self, doc_id: str) -> DocumentSegment | None: stmt = select(DocumentSegment).where( DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id ) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 43be9cde69..5f94129a0c 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Any, Optional, cast +from typing import Any, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: Optional[str] = None): + def __init__(self, model_instance: ModelInstance, user: str | None = None): self._model_instance = model_instance self._user = user diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index 800422d888..8e92191568 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from models.dataset import DocumentSegment @@ -19,5 +17,5 @@ class RetrievalSegments(BaseModel): model_config = {"arbitrary_types_allowed": True} segment: DocumentSegment - child_chunks: Optional[list[RetrievalChildChunk]] = None - score: Optional[float] = None + child_chunks: list[RetrievalChildChunk] | None = None + score: float | None = None diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 00120425c9..aca879df7d 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -1,23 +1,23 @@ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel class RetrievalSourceMetadata(BaseModel): - position: Optional[int] = None - dataset_id: Optional[str] = None - dataset_name: Optional[str] = None - document_id: Optional[str] = None - document_name: Optional[str] = None - data_source_type: Optional[str] = None - segment_id: Optional[str] = None - retriever_from: Optional[str] = None - score: Optional[float] = None - hit_count: Optional[int] = None - word_count: Optional[int] = None - segment_position: Optional[int] = None - index_node_hash: Optional[str] = None - content: Optional[str] = None - page: Optional[int] = None - doc_metadata: Optional[dict[str, Any]] = None - title: Optional[str] = None + position: int | None = None + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + data_source_type: str | None = None + segment_id: str | None = None + retriever_from: str | None = None + score: float | None = None + hit_count: int | None = None + word_count: int | None = None + segment_position: int | None = None + index_node_hash: str | None = None + content: str | None = None + page: int | None = None + doc_metadata: dict[str, Any] | None = None + title: str | None = None diff --git a/api/core/rag/entities/context_entities.py b/api/core/rag/entities/context_entities.py index cd18ad081f..a2b03d54ba 100644 --- a/api/core/rag/entities/context_entities.py +++ b/api/core/rag/entities/context_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -9,4 +7,4 @@ class DocumentContext(BaseModel): """ content: str - score: Optional[float] = None + score: float | None = None diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index 1f054bccdb..723229a332 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -43,5 +43,5 @@ class MetadataCondition(BaseModel): Metadata Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index 60dbc449f7..1f91a3ece1 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -12,7 +12,7 @@ import mimetypes from collections.abc import Generator, Mapping from io import BufferedReader, BytesIO from pathlib import Path, PurePath -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, model_validator @@ -30,17 +30,17 @@ class Blob(BaseModel): """ data: Union[bytes, str, None] = None # Raw data - mimetype: Optional[str] = None # Not to be confused with a file extension + mimetype: str | None = None # Not to be confused with a file extension encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string # Location where the original content was found # Represent location on the local file system # Useful for situations where downstream code assumes it must work with file paths # rather than in-memory content. - path: Optional[PathLike] = None + path: PathLike | None = None model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) @property - def source(self) -> Optional[str]: + def source(self) -> str | None: """The source location of the blob as string if known otherwise none.""" return str(self.path) if self.path else None @@ -91,7 +91,7 @@ class Blob(BaseModel): path: PathLike, *, encoding: str = "utf-8", - mime_type: Optional[str] = None, + mime_type: str | None = None, guess_type: bool = True, ) -> Blob: """Load the blob from a path like object. @@ -120,8 +120,8 @@ class Blob(BaseModel): data: Union[str, bytes], *, encoding: str = "utf-8", - mime_type: Optional[str] = None, - path: Optional[str] = None, + mime_type: str | None = None, + path: str | None = None, ) -> Blob: """Initialize the blob from in-memory data. diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 5b67403902..3bfae9d6bd 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" import csv -from typing import Optional import pandas as pd @@ -21,10 +20,10 @@ class CSVExtractor(BaseExtractor): def __init__( self, file_path: str, - encoding: Optional[str] = None, + encoding: str | None = None, autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + source_column: str | None = None, + csv_args: dict | None = None, ): """Initialize with file path.""" self._file_path = file_path diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 52d64f591f..04a35d6f1f 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, ConfigDict from models.dataset import Document @@ -14,7 +12,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Optional[Document] = None + document: Document | None = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) @@ -43,10 +41,10 @@ class ExtractSetting(BaseModel): """ datasource_type: str - upload_file: Optional[UploadFile] = None - notion_info: Optional[NotionInfo] = None - website_info: Optional[WebsiteInfo] = None - document_model: Optional[str] = None + upload_file: UploadFile | None = None + notion_info: NotionInfo | None = None + website_info: WebsiteInfo | None = None + document_model: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, **data): diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index baa3fdf2eb..ea9c6bd73a 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional, cast +from typing import cast import pandas as pd from openpyxl import load_workbook @@ -18,7 +18,7 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): + def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index b5ea08173b..0c70844000 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -1,7 +1,7 @@ import re import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Union from urllib.parse import unquote from configs import dify_config @@ -90,7 +90,7 @@ class ExtractProcessor: @classmethod def extract( - cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: Optional[str] = None + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: @@ -104,7 +104,7 @@ class ExtractProcessor: input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE - extractor: Optional[BaseExtractor] = None + extractor: BaseExtractor | None = None if etl_type == "Unstructured": unstructured_api_url = dify_config.UNSTRUCTURED_API_URL or "" unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or "" diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 17f7d8661f..00004409d6 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -1,17 +1,17 @@ """Document loader helpers.""" import concurrent.futures -from typing import NamedTuple, Optional, cast +from typing import NamedTuple, cast class FileEncoding(NamedTuple): """A file encoding as the NamedTuple.""" - encoding: Optional[str] + encoding: str | None """The encoding of the file.""" confidence: float """The confidence of the encoding.""" - language: Optional[str] + language: str | None """The language of the file.""" diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index 3845392c8d..79d6ae2dac 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -2,7 +2,6 @@ import re from pathlib import Path -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -22,7 +21,7 @@ class MarkdownExtractor(BaseExtractor): file_path: str, remove_hyperlinks: bool = False, remove_images: bool = False, - encoding: Optional[str] = None, + encoding: str | None = None, autodetect_encoding: bool = True, ): """Initialize with file path.""" @@ -45,13 +44,13 @@ class MarkdownExtractor(BaseExtractor): return documents - def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]: + def markdown_to_tups(self, markdown_text: str) -> list[tuple[str | None, str]]: """Convert a markdown file to a dictionary. The keys are the headers and the values are the text under each header. """ - markdown_tups: list[tuple[Optional[str], str]] = [] + markdown_tups: list[tuple[str | None, str]] = [] lines = markdown_text.split("\n") current_header = None @@ -94,7 +93,7 @@ class MarkdownExtractor(BaseExtractor): content = re.sub(pattern, r"\1", content) return content - def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: + def parse_tups(self, filepath: str) -> list[tuple[str | None, str]]: """Parse file into tuples.""" content = "" try: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index fa96d73cf2..1779f26994 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,7 +1,7 @@ import json import logging import operator -from typing import Any, Optional, cast +from typing import Any, cast import requests from sqlalchemy import select @@ -36,8 +36,8 @@ class NotionExtractor(BaseExtractor): notion_obj_id: str, notion_page_type: str, tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, + document_model: DocumentModel | None = None, + notion_access_token: str | None = None, ): self._notion_access_token = None self._document_model = document_model @@ -328,7 +328,7 @@ class NotionExtractor(BaseExtractor): result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: Optional[DocumentModel]): + def update_last_edited_time(self, document_model: DocumentModel | None): if not document_model: return diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 3c43f34104..80530d99a6 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -2,7 +2,6 @@ import contextlib from collections.abc import Iterator -from typing import Optional from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -18,7 +17,7 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, file_cache_key: Optional[str] = None): + def __init__(self, file_path: str, file_cache_key: str | None = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index a00d328cb1..93f301ceff 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" from pathlib import Path -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -16,7 +15,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): + def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2427de8292..ad04bd0bd1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,6 @@ import base64 import contextlib import logging -from typing import Optional from bs4 import BeautifulSoup @@ -17,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index fa91f7dd03..fc14ee6275 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import pypandoc # type: ignore @@ -20,7 +19,7 @@ class UnstructuredEpubExtractor(BaseExtractor): def __init__( self, file_path: str, - api_url: Optional[str] = None, + api_url: str | None = None, api_key: str = "", ): """Initialize with file path.""" diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 0a0c8d3a1c..23030d7739 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -16,7 +15,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index d363449c29..f29e639d1b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index ecc272a2f0..c12a55ee4b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index e7bf6fd2e6..99e3eec501 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 916cdc3f2b..d75e166f1b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index c59a70ea57..fe983aa86a 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient @@ -9,7 +9,7 @@ class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: Optional[dict | Any] = None): + def crawl_url(self, url, options: dict | Any | None = None): options = options or {} spider_options = { "max_depth": 1, diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index c099fd1d5c..1e904e72e2 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" from abc import ABC, abstractmethod -from typing import Optional from configs import dify_config from core.model_manager import ModelInstance @@ -31,7 +30,7 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod @@ -52,7 +51,7 @@ class BaseIndexProcessor(ABC): max_tokens: int, chunk_overlap: int, separator: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 997b0b953b..5e0b24c354 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,7 +1,6 @@ """Paragraph index processor.""" import uuid -from typing import Optional from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -85,7 +84,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 514596200f..f87e61b51c 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,7 +1,6 @@ """Paragraph index processor.""" import uuid -from typing import Optional from configs import dify_config from core.model_manager import ModelInstance @@ -109,7 +108,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): ] vector.create(formatted_child_documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False @@ -187,7 +186,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): document_node: Document, rules: Rule, process_rule_mode: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> list[ChildDocument]: if not rules.subchunk_segmentation: raise ValueError("No subchunk segmentation found in rules.") diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index a4ec828e2f..2ca444ca86 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,6 @@ import logging import re import threading import uuid -from typing import Optional import pandas as pd from flask import Flask, current_app @@ -128,7 +127,7 @@ class QAIndexProcessor(BaseIndexProcessor): vector = Vector(dataset) vector.create(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index ff63a6780e..b70d8bf559 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -10,7 +10,7 @@ class ChildDocument(BaseModel): page_content: str - vector: Optional[list[float]] = None + vector: list[float] | None = None """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). @@ -23,16 +23,16 @@ class Document(BaseModel): page_content: str - vector: Optional[list[float]] = None + vector: list[float] | None = None """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ metadata: dict = Field(default_factory=dict) - provider: Optional[str] = "dify" + provider: str | None = "dify" - children: Optional[list[ChildDocument]] = None + children: list[ChildDocument] | None = None class BaseDocumentTransformer(ABC): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 818b04b2ff..3561def008 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from core.rag.models.document import Document @@ -10,9 +9,9 @@ class BaseRerankRunner(ABC): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 7a6ebd1f39..e855b0083f 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_manager import ModelInstance from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner @@ -13,9 +11,9 @@ class RerankModelRunner(BaseRerankRunner): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index ab49e43b70..c455db6095 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -1,6 +1,5 @@ import math from collections import Counter -from typing import Optional import numpy as np @@ -22,9 +21,9 @@ class WeightRerankRunner(BaseRerankRunner): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 93bad23f2b..b08f80da49 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -4,7 +4,7 @@ import re import threading from collections import Counter, defaultdict from collections.abc import Generator, Mapping -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import Float, and_, or_, select, text @@ -85,9 +85,9 @@ class DatasetRetrieval: show_retrieve_source: bool, hit_callback: DatasetIndexToolCallbackHandler, message_id: str, - memory: Optional[TokenBufferMemory] = None, - inputs: Optional[Mapping[str, Any]] = None, - ) -> Optional[str]: + memory: TokenBufferMemory | None = None, + inputs: Mapping[str, Any] | None = None, + ) -> str | None: """ Retrieve dataset. :param app_id: app_id @@ -290,9 +290,9 @@ class DatasetRetrieval: model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, - metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, - metadata_condition: Optional[MetadataCondition] = None, + message_id: str | None = None, + metadata_filter_document_ids: dict[str, list[str]] | None = None, + metadata_condition: MetadataCondition | None = None, ): tools = [] for dataset in available_datasets: @@ -410,12 +410,12 @@ class DatasetRetrieval: top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict[str, Any]] = None, + reranking_model: dict | None = None, + weights: dict[str, Any] | None = None, reranking_enable: bool = True, - message_id: Optional[str] = None, - metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, - metadata_condition: Optional[MetadataCondition] = None, + message_id: str | None = None, + metadata_filter_document_ids: dict[str, list[str]] | None = None, + metadata_condition: MetadataCondition | None = None, ): if not available_datasets: return [] @@ -505,9 +505,7 @@ class DatasetRetrieval: return all_documents - def _on_retrieval_end( - self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None - ): + def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: @@ -588,8 +586,8 @@ class DatasetRetrieval: query: str, top_k: int, all_documents: list, - document_ids_filter: Optional[list[str]] = None, - metadata_condition: Optional[MetadataCondition] = None, + document_ids_filter: list[str] | None = None, + metadata_condition: MetadataCondition | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -664,7 +662,7 @@ class DatasetRetrieval: hit_callback: DatasetIndexToolCallbackHandler, user_id: str, inputs: dict, - ) -> Optional[list[DatasetRetrieverBaseTool]]: + ) -> list[DatasetRetrieverBaseTool] | None: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -853,9 +851,9 @@ class DatasetRetrieval: user_id: str, metadata_filtering_mode: str, metadata_model_config: ModelConfig, - metadata_filtering_conditions: Optional[MetadataFilteringCondition], + metadata_filtering_conditions: MetadataFilteringCondition | None, inputs: dict, - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", @@ -950,7 +948,7 @@ class DatasetRetrieval: def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: # get all metadata field metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) metadata_fields = db.session.scalars(metadata_stmt).all() @@ -1005,7 +1003,7 @@ class DatasetRetrieval: return automatic_metadata_filters def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index d654463be9..8356861242 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer @@ -24,7 +24,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @classmethod def from_encoder( cls: type[TS], - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 **kwargs: Any, @@ -48,7 +48,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): - def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): + def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index bfa14e03d6..41e6d771e9 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import ( Any, Literal, - Optional, TypeVar, Union, ) @@ -71,7 +70,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -94,7 +93,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): metadatas.append(doc.metadata or {}) return self.create_documents(texts, metadatas=metadatas) - def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + def _join_docs(self, docs: list[str], separator: str) -> str | None: text = separator.join(docs) text = text.strip() if text == "": @@ -194,7 +193,7 @@ class TokenTextSplitter(TextSplitter): def __init__( self, encoding_name: str = "gpt2", - model_name: Optional[str] = None, + model_name: str | None = None, allowed_special: Union[Literal["all"], Set[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -245,7 +244,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): def __init__( self, - separators: Optional[list[str]] = None, + separators: list[str] | None = None, keep_separator: bool = True, **kwargs: Any, ): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d6f40491b6..eda7b54d6a 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -6,7 +6,7 @@ providing improved performance by offloading database operations to background w """ import logging -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -39,8 +39,8 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): _session_factory: sessionmaker _tenant_id: str - _app_id: Optional[str] - _triggered_from: Optional[WorkflowRunTriggeredFrom] + _app_id: str | None + _triggered_from: WorkflowRunTriggeredFrom | None _creator_user_id: str _creator_user_role: CreatorUserRole @@ -48,8 +48,8 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowRunTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, ): """ Initialize the repository with Celery task configuration and context information. diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 95ad9f25fe..21a0b7eefe 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -44,8 +44,8 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): _session_factory: sessionmaker _tenant_id: str - _app_id: Optional[str] - _triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom] + _app_id: str | None + _triggered_from: WorkflowNodeExecutionTriggeredFrom | None _creator_user_id: str _creator_user_role: CreatorUserRole _execution_cache: dict[str, WorkflowNodeExecution] @@ -55,8 +55,8 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, ): """ Initialize the repository with Celery task configuration and context information. @@ -151,7 +151,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 46b028b219..ca5e33d83a 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -4,7 +4,7 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -44,8 +44,8 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowRunTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, ): """ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8702af9f80..de5fca9f44 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -5,7 +5,7 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. import json import logging from collections.abc import Sequence -from typing import Optional, Union +from typing import Union import psycopg2.errors from sqlalchemy import UnaryExpression, asc, desc, select @@ -52,8 +52,8 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, ): """ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. @@ -279,7 +279,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_db_models_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecutionModel]: """ Retrieve all WorkflowNodeExecution database models for a specific workflow run. @@ -334,7 +334,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 5a2b803932..6e0462c530 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from models.model import File @@ -46,9 +46,9 @@ class Tool(ABC): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage]: if self.runtime and self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) @@ -96,17 +96,17 @@ class Tool(ABC): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: pass def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters @@ -119,9 +119,9 @@ class Tool(ABC): def get_merged_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get merged runtime parameters @@ -196,7 +196,7 @@ class Tool(ABC): message=ToolInvokeMessage.TextMessage(text=text), ) - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index ddec7b1329..3de0014c61 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from openai import BaseModel from pydantic import Field @@ -13,9 +13,9 @@ class ToolRuntime(BaseModel): """ tenant_id: str - tool_id: Optional[str] = None - invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None + tool_id: str | None = None + invoke_from: InvokeFrom | None = None + tool_invoke_from: ToolInvokeFrom | None = None credentials: dict[str, Any] = Field(default_factory=dict) credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 5c24920871..af9b5b31c2 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.file.enums import FileType from core.file.file_manager import download @@ -18,9 +18,9 @@ class ASRTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore @@ -56,9 +56,9 @@ class ASRTool(BuiltinTool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: parameters = [] diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f191968812..8bc159bb85 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -16,9 +16,9 @@ class TTSTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: provider, model = tool_parameters.get("model").split("#") # type: ignore voice = tool_parameters.get(f"voice#{provider}#{model}") @@ -72,9 +72,9 @@ class TTSTool(BuiltinTool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: parameters = [] diff --git a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py index b4e650e0ed..4383943199 100644 --- a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py +++ b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage from core.tools.builtin_tool.tool import BuiltinTool @@ -12,9 +12,9 @@ class SimpleCode(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke simple code diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index d054afac96..44f94c2723 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any from pytz import timezone as pytz_timezone @@ -13,9 +13,9 @@ class CurrentTimeTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tools diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index a8fd6ec2cd..197b062e44 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class LocaltimeToTimestampTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert localtime to timestamp diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 0ef6331530..462e4be5ce 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class TimestampToLocaltimeTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert timestamp to localtime diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index 91316b859a..babfa9bcd9 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class TimezoneConversionTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert time to equivalent time zone diff --git a/api/core/tools/builtin_tool/providers/time/tools/weekday.py b/api/core/tools/builtin_tool/providers/time/tools/weekday.py index 158ce701c0..e26b316bd5 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/weekday.py +++ b/api/core/tools/builtin_tool/providers/time/tools/weekday.py @@ -1,7 +1,7 @@ import calendar from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -12,9 +12,9 @@ class WeekdayTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Calculate the day of the week for a given date diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py index 3bee710879..9d668ac9eb 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -12,9 +12,9 @@ class WebscraperTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tools diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 190af999b1..13dd2114d3 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator from dataclasses import dataclass from os import getenv -from typing import Any, Optional, Union +from typing import Any, Union from urllib.parse import urlencode import httpx @@ -376,9 +376,9 @@ class ApiTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke http request diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ca3be26ff9..ee2b438f5b 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -14,9 +14,9 @@ class ToolApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None + output_schema: dict | None = None ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]] @@ -28,25 +28,25 @@ class ToolProviderApiEntity(BaseModel): name: str # identifier description: I18nObject icon: str | dict - icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool") label: I18nObject # label type: ToolProviderType - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None + masked_credentials: dict | None = None + original_credentials: dict | None = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + plugin_id: str | None = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool") tools: list[ToolApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) # MCP - server_url: Optional[str] = Field(default="", description="The server url of the tool") + server_url: str | None = Field(default="", description="The server url of the tool") updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool") - timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool") - sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool") - masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool") - original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool") + server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool") + timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool") + sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") + masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") + original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") @field_validator("tools", mode="before") @classmethod diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index aadbbeb843..2c6d9c1964 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field @@ -9,9 +7,9 @@ class I18nObject(BaseModel): """ en_US: str - zh_Hans: Optional[str] = Field(default=None) - pt_BR: Optional[str] = Field(default=None) - ja_JP: Optional[str] = Field(default=None) + zh_Hans: str | None = Field(default=None) + pt_BR: str | None = Field(default=None) + ja_JP: str | None = Field(default=None) def __init__(self, **data): super().__init__(**data) diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index ffeeabbc1c..eba20b07f0 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from core.tools.entities.tool_entities import ToolParameter @@ -16,14 +14,14 @@ class ApiToolBundle(BaseModel): # method method: str # summary - summary: Optional[str] = None + summary: str | None = None # operation_id - operation_id: Optional[str] = None + operation_id: str | None = None # parameters - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None # author author: str # icon - icon: Optional[str] = None + icon: str | None = None # openapi operation openapi: dict diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index f934b0bf96..62dad1a50b 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -2,7 +2,7 @@ import base64 import contextlib from collections.abc import Mapping from enum import StrEnum, auto -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator @@ -182,11 +182,11 @@ class ToolInvokeMessage(BaseModel): id: str label: str = Field(..., description="The label of the log") - parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") - error: Optional[str] = Field(default=None, description="The error message") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") status: LogStatus = Field(..., description="The status of the log") data: Mapping[str, Any] = Field(..., description="Detailed log data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log") class RetrieverResourceMessage(BaseModel): retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") @@ -241,7 +241,7 @@ class ToolInvokeMessage(BaseModel): class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - file_var: Optional[dict[str, Any]] = None + file_var: dict[str, Any] | None = None class ToolParameter(PluginParameter): @@ -285,11 +285,11 @@ class ToolParameter(PluginParameter): LLM = auto() # will be set by LLM type: ToolParameterType = Field(..., description="The type of the parameter") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + human_description: I18nObject | None = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") - llm_description: Optional[str] = None + llm_description: str | None = None # MCP object and array type parameters use this field to store the schema - input_schema: Optional[dict] = None + input_schema: dict | None = None @classmethod def get_simple_instance( @@ -298,7 +298,7 @@ class ToolParameter(PluginParameter): llm_description: str, typ: ToolParameterType, required: bool, - options: Optional[list[str]] = None, + options: list[str] | None = None, ) -> "ToolParameter": """ get a simple tool parameter @@ -339,9 +339,9 @@ class ToolProviderIdentity(BaseModel): name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") - icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | None = Field(default=None, description="The dark icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( + tags: list[ToolLabelEnum] | None = Field( default=[], description="The tags of the tool", ) @@ -352,7 +352,7 @@ class ToolIdentity(BaseModel): name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") provider: str = Field(..., description="The provider of the tool") - icon: Optional[str] = None + icon: str | None = None class ToolDescription(BaseModel): @@ -363,8 +363,8 @@ class ToolDescription(BaseModel): class ToolEntity(BaseModel): identity: ToolIdentity parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None - output_schema: Optional[dict] = None + description: ToolDescription | None = None + output_schema: dict | None = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") # pydantic configs @@ -385,9 +385,9 @@ class OAuthSchema(BaseModel): class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity - plugin_id: Optional[str] = None + plugin_id: str | None = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = None + oauth_schema: OAuthSchema | None = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -410,8 +410,8 @@ class ToolInvokeMeta(BaseModel): """ time_cost: float = Field(..., description="The time cost of the tool invoke") - error: Optional[str] = None - tool_config: Optional[dict] = None + error: str | None = None + tool_config: dict | None = None @classmethod def empty(cls) -> "ToolInvokeMeta": @@ -463,11 +463,11 @@ class ToolSelector(BaseModel): type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") required: bool = Field(..., description="Whether the parameter is required") description: str = Field(..., description="The description of the parameter") - default: Optional[Union[int, float, str]] = None - options: Optional[list[PluginParameterOption]] = None + default: Union[int, float, str] | None = None + options: list[PluginParameterOption] | None = None provider_id: str = Field(..., description="The id of the provider") - credential_id: Optional[str] = Field(default=None, description="The id of the credential") + credential_id: str | None = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 5f6eb045ab..60b393e1ea 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional, Self +from typing import Any, Self from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController @@ -25,9 +25,9 @@ class MCPToolProviderController(ToolProviderController): provider_id: str, tenant_id: str, server_url: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): super().__init__(entity) self.entity: ToolProviderEntityWithPlugin = entity diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 21d256ae03..976d4dc942 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,7 +1,7 @@ import base64 import json from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient @@ -20,9 +20,9 @@ class MCPTool(Tool): icon: str, server_url: str, provider_id: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): super().__init__(entity, runtime) self.tenant_id = tenant_id @@ -40,9 +40,9 @@ class MCPTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: from core.tools.errors import ToolInvokeError diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index e649caec1d..828dc3b810 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.plugin.impl.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -16,7 +16,7 @@ class PluginTool(Tool): self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters: Optional[list[ToolParameter]] = None + self.runtime_parameters: list[ToolParameter] | None = None def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN @@ -25,9 +25,9 @@ class PluginTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() @@ -57,9 +57,9 @@ class PluginTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index c3fdc37303..0154ffe883 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterable from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from yarl import URL @@ -51,10 +51,10 @@ class ToolEngine: message: Message, invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + trace_manager: TraceQueueManager | None = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> tuple[str, list[str], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. @@ -152,10 +152,10 @@ class ToolEngine: user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, - thread_pool_id: Optional[str] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + thread_pool_id: str | None = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Workflow invokes the tool with the given arguments. @@ -196,9 +196,9 @@ class ToolEngine: tool: Tool, tool_parameters: dict, user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: """ Invoke the tool with the given arguments. diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ad650196ce..6289f1d335 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,7 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Optional, Union +from typing import Union from uuid import uuid4 import httpx @@ -72,10 +72,10 @@ class ToolFileManager: *, user_id: str, tenant_id: str, - conversation_id: Optional[str], + conversation_id: str | None, file_binary: bytes, mimetype: str, - filename: Optional[str] = None, + filename: str | None = None, ) -> ToolFile: extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex @@ -112,7 +112,7 @@ class ToolFileManager: user_id: str, tenant_id: str, file_url: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> ToolFile: # try to download image try: @@ -217,7 +217,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optional[Generator], Optional[ToolFile]]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: """ get file binary diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b29da3f0ba..f1f4969d22 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Union, cast import sqlalchemy as sa from pydantic import TypeAdapter @@ -157,7 +157,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -358,7 +358,7 @@ class ToolManager: app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional[VariablePool] = None, + variable_pool: VariablePool | None = None, ) -> Tool: """ get the agent tool runtime @@ -400,7 +400,7 @@ class ToolManager: node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional[VariablePool] = None, + variable_pool: VariablePool | None = None, ) -> Tool: """ get the workflow tool runtime @@ -443,7 +443,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Tool: """ get tool runtime from plugin @@ -977,7 +977,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional[VariablePool], + variable_pool: VariablePool | None, tool_configurations: dict[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index 2e572099b3..ac2967d0c1 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -13,7 +12,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): description: str = "use this to retrieve a dataset. " tenant_id: str top_k: int = 4 - score_threshold: Optional[float] = None + score_threshold: float | None = None hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] return_resource: bool retriever_from: str diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index b536c5a25c..0e2237befd 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel, Field from sqlalchemy import select @@ -37,7 +37,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - user_id: Optional[str] = None + user_id: str | None = None retrieve_config: DatasetRetrieveConfigEntity inputs: dict diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index d5803e33e7..a62d419243 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -87,9 +87,9 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: return [ ToolParameter( @@ -112,9 +112,9 @@ class DatasetRetrieverTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke dataset retriever tool diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 5820be0ffb..45ad14cb8e 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -1,6 +1,6 @@ import contextlib from copy import deepcopy -from typing import Any, Optional, Protocol +from typing import Any, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter @@ -13,7 +13,7 @@ class ProviderConfigCache(Protocol): Interface for provider configuration cache operations """ - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider configuration""" ... diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bf075bd730..0851a54338 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -3,7 +3,6 @@ from collections.abc import Generator from datetime import date, datetime from decimal import Decimal from mimetypes import guess_extension -from typing import Optional from uuid import UUID import numpy as np @@ -60,7 +59,7 @@ class ToolFileMessageTransformer: messages: Generator[ToolInvokeMessage, None, None], user_id: str, tenant_id: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Transform tool message and handle file download @@ -165,5 +164,5 @@ class ToolFileMessageTransformer: yield message @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: + def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 251d914800..526f5c8b9a 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models. """ import json -from typing import Optional, cast +from typing import cast from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult @@ -51,7 +51,7 @@ class ModelInvocationUtils: if not schema: raise InvokeModelError("No model schema found") - max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index cae21633fe..2e306db6c7 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -2,7 +2,6 @@ import re from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError -from typing import Optional from flask import request from requests import get @@ -198,9 +197,9 @@ class ApiBasedToolSchemaParser: return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} - typ: Optional[str] = None + typ: str | None = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py index f3c946b95f..6b7007842d 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -2,7 +2,7 @@ import base64 import hashlib import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from Crypto.Cipher import AES from Crypto.Random import get_random_bytes @@ -28,7 +28,7 @@ class SystemOAuthEncrypter: using AES-CBC mode with a key derived from the application's SECRET_KEY. """ - def __init__(self, secret_key: Optional[str] = None): + def __init__(self, secret_key: str | None = None): """ Initialize the OAuth encrypter. @@ -130,7 +130,7 @@ class SystemOAuthEncrypter: # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: +def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: """ Create an OAuth encrypter instance. @@ -144,7 +144,7 @@ def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAu # Global encrypter instance (for backward compatibility) -_oauth_encrypter: Optional[SystemOAuthEncrypter] = None +_oauth_encrypter: SystemOAuthEncrypter | None = None def get_system_oauth_encrypter() -> SystemOAuthEncrypter: diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index d8403c2e15..52c16c34a0 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -2,7 +2,7 @@ import mimetypes import re from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import unquote import chardet @@ -27,7 +27,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str: return text[cursor : cursor + max_length] -def get_url(url: str, user_agent: Optional[str] = None) -> str: +def get_url(url: str, user_agent: str | None = None) -> str: """Fetch URL and return the contents as a string.""" headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 18e6993b38..4d9c8895fc 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,5 +1,4 @@ from collections.abc import Mapping -from typing import Optional from pydantic import Field @@ -207,7 +206,7 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore + def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore """ get tool by name diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index c9d62388f2..6a1ac51528 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional +from typing import Any from sqlalchemy import select @@ -39,7 +39,7 @@ class WorkflowTool(Tool): entity: ToolEntity, runtime: ToolRuntime, label: str = "Workflow", - thread_pool_id: Optional[str] = None, + thread_pool_id: str | None = None, ): self.workflow_app_id = workflow_app_id self.workflow_as_tool_id = workflow_as_tool_id @@ -63,9 +63,9 @@ class WorkflowTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke the tool diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index ec62be605f..6fce5a83b9 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -37,7 +35,7 @@ _TEXT_COLOR_MAPPING = { class WorkflowLoggingCallback(WorkflowCallback): def __init__(self): - self.current_node_id: Optional[str] = None + self.current_node_id: str | None = None def on_event(self, event: GraphEngineEvent): if isinstance(event, GraphRunStartedEvent): @@ -250,7 +248,7 @@ class WorkflowLoggingCallback(WorkflowCallback): ) self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") - def print_text(self, text: str, color: Optional[str] = None, end: str = "\n"): + def print_text(self, text: str, color: str | None = None, end: str = "\n"): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(f"{text_to_print}", end=end) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 687ec8e47c..d672136d97 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -14,16 +14,16 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[Mapping[str, Any]] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata - llm_usage: Optional[LLMUsage] = None # llm usage + inputs: Mapping[str, Any] | None = None # node inputs + process_data: Mapping[str, Any] | None = None # process data + outputs: Mapping[str, Any] | None = None # node outputs + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # node metadata + llm_usage: LLMUsage | None = None # llm usage - edge_source_handle: Optional[str] = None # source handle id of node with multiple branches + edge_source_handle: str | None = None # source handle id of node with multiple branches - error: Optional[str] = None # error message if status is failed - error_type: Optional[str] = None # error type if status is failed + error: str | None = None # error message if status is failed + error_type: str | None = None # error type if status is failed # single step node run retry retry_index: int = 0 diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index f00dc11aa6..2e86605419 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -8,7 +8,7 @@ implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -45,7 +45,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any] = Field(...) inputs: Mapping[str, Any] = Field(...) - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING error_message: str = Field(default="") @@ -54,7 +54,7 @@ class WorkflowExecution(BaseModel): exceptions_count: int = Field(default=0) started_at: datetime = Field(...) - finished_at: Optional[datetime] = None + finished_at: datetime | None = None @property def elapsed_time(self) -> float: diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index ff72d7cbf3..e00099cda8 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -9,7 +9,7 @@ and don't contain implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -77,41 +77,41 @@ class WorkflowNodeExecution(BaseModel): # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: Optional[str] = None + node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow index: int # Sequence number for ordering in trace visualization - predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + predecessor_node_id: str | None = None # ID of the node that executed before this one node_id: str # ID of the node being executed node_type: NodeType # Type of node (e.g., start, llm, knowledge) title: str # Display title of the node # Execution data - inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node - process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data - outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + inputs: Mapping[str, Any] | None = None # Input variables used by this node + process_data: Mapping[str, Any] | None = None # Intermediate processing data + outputs: Mapping[str, Any] | None = None # Output variables produced by this node # Execution state status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: Optional[str] = None # Error message if execution failed + error: str | None = None # Error message if execution failed elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds # Additional metadata - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) # Timing information created_at: datetime # When execution started - finished_at: Optional[datetime] = None # When execution completed + finished_at: datetime | None = None # When execution completed def update_from_mapping( self, - inputs: Optional[Mapping[str, Any]] = None, - process_data: Optional[Mapping[str, Any]] = None, - outputs: Optional[Mapping[str, Any]] = None, - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, + inputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, ): """ Update the model from mappings. diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 6e72f8b152..c2865cdb02 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -29,7 +29,7 @@ class GraphRunStartedEvent(BaseGraphEvent): class GraphRunSucceededEvent(BaseGraphEvent): - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None """outputs""" @@ -40,7 +40,7 @@ class GraphRunFailedEvent(BaseGraphEvent): class GraphRunPartialSucceededEvent(BaseGraphEvent): exceptions_count: int = Field(..., description="exception count") - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None ########################################### @@ -54,33 +54,33 @@ class BaseNodeEvent(GraphEngineEvent): node_type: NodeType = Field(..., description="node type") node_data: BaseNodeData = Field(..., description="node data") route_node_state: RouteNodeState = Field(..., description="route node state") - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" # The version of the node, or "1" if not specified. node_version: str = "1" class NodeRunStartedEvent(BaseNodeEvent): - predecessor_node_id: Optional[str] = None + predecessor_node_id: str | None = None """predecessor node id""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration node parallel mode run id""" - agent_strategy: Optional[AgentNodeStrategyInit] = None + agent_strategy: AgentNodeStrategyInit | None = None class NodeRunStreamChunkEvent(BaseNodeEvent): chunk_content: str = Field(..., description="chunk content") - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None """from variable selector""" @@ -125,13 +125,13 @@ class BaseParallelBranchEvent(GraphEngineEvent): """parallel id""" parallel_start_node_id: str = Field(..., description="parallel start node id") """parallel start node id""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -157,45 +157,45 @@ class BaseIterationEvent(GraphEngineEvent): iteration_node_id: str = Field(..., description="iteration node id") iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") iteration_node_data: BaseNodeData = Field(..., description="node data") - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" class IterationRunStartedEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class IterationRunNextEvent(BaseIterationEvent): index: int = Field(..., description="index") - pre_iteration_output: Optional[Any] = None - duration: Optional[float] = None + pre_iteration_output: Any | None = None + duration: float | None = None class IterationRunSucceededEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - iteration_duration_map: Optional[dict[str, float]] = None + iteration_duration_map: dict[str, float] | None = None class IterationRunFailedEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") @@ -210,45 +210,45 @@ class BaseLoopEvent(GraphEngineEvent): loop_node_id: str = Field(..., description="loop node id") loop_node_type: NodeType = Field(..., description="node type, loop or loop") loop_node_data: BaseNodeData = Field(..., description="node data") - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """loop run in parallel mode run id""" class LoopRunStartedEvent(BaseLoopEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class LoopRunNextEvent(BaseLoopEvent): index: int = Field(..., description="index") - pre_loop_output: Optional[Any] = None - duration: Optional[float] = None + pre_loop_output: Any | None = None + duration: float | None = None class LoopRunSucceededEvent(BaseLoopEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - loop_duration_map: Optional[dict[str, float]] = None + loop_duration_map: dict[str, float] | None = None class LoopRunFailedEvent(BaseLoopEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") @@ -270,7 +270,7 @@ class AgentLogEvent(BaseAgentEvent): error: str | None = Field(..., description="error") status: str = Field(..., description="status") data: Mapping[str, Any] = Field(..., description="data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata") + metadata: Mapping[str, Any] | None = Field(default=None, description="metadata") node_id: str = Field(..., description="agent node id") diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index d8d1825d94..bb4a7e1e81 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,7 +1,7 @@ import uuid from collections import defaultdict from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel, Field @@ -17,18 +17,18 @@ from core.workflow.nodes.end.entities import EndStreamParam class GraphEdge(BaseModel): source_node_id: str = Field(..., description="source node id") target_node_id: str = Field(..., description="target node id") - run_condition: Optional[RunCondition] = None + run_condition: RunCondition | None = None """run condition""" class GraphParallel(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") start_from_node_id: str = Field(..., description="start from node id") - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id""" - end_to_node_id: Optional[str] = None + end_to_node_id: str | None = None """end to node id""" @@ -54,7 +54,7 @@ class Graph(BaseModel): end_stream_param: EndStreamParam = Field(..., description="end stream param") @classmethod - def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": + def init(cls, graph_config: Mapping[str, Any], root_node_id: str | None = None) -> "Graph": """ Init graph @@ -253,7 +253,7 @@ class Graph(BaseModel): start_node_id: str, parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str], - parent_parallel: Optional[GraphParallel] = None, + parent_parallel: GraphParallel | None = None, ): """ Recursively add parallel ids @@ -422,9 +422,9 @@ class Graph(BaseModel): cls, parallel_mapping: dict[str, GraphParallel], graph_edge: GraphEdge, - parallel: Optional[GraphParallel] = None, - parent_parallel: Optional[GraphParallel] = None, - ) -> Optional[GraphParallel]: + parallel: GraphParallel | None = None, + parent_parallel: GraphParallel | None = None, + ) -> GraphParallel | None: """ Get current parallel """ diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py index eedce8842b..7b9a379215 100644 --- a/api/core/workflow/graph_engine/entities/run_condition.py +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -1,5 +1,5 @@ import hashlib -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -10,10 +10,10 @@ class RunCondition(BaseModel): type: Literal["branch_identify", "condition"] """condition type""" - branch_identify: Optional[str] = None + branch_identify: str | None = None """branch identify like: sourceHandle, required when type is branch_identify""" - conditions: Optional[list[Condition]] = None + conditions: list[Condition] | None = None """conditions to run the node, required when type is condition""" @property diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index d010ce05b8..c6b8a0b334 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,7 +1,6 @@ import uuid from datetime import datetime from enum import StrEnum, auto -from typing import Optional from pydantic import BaseModel, Field @@ -24,7 +23,7 @@ class RouteNodeState(BaseModel): node_id: str """node id""" - node_run_result: Optional[NodeRunResult] = None + node_run_result: NodeRunResult | None = None """node run result""" status: Status = Status.RUNNING @@ -33,16 +32,16 @@ class RouteNodeState(BaseModel): start_at: datetime """start time""" - paused_at: Optional[datetime] = None + paused_at: datetime | None = None """paused time""" - finished_at: Optional[datetime] = None + finished_at: datetime | None = None """finished time""" - failed_reason: Optional[str] = None + failed_reason: str | None = None """failed reason""" - paused_by: Optional[str] = None + paused_by: str | None = None """paused by""" index: int = 1 diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 9b0b187a7c..bdb8070add 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -6,7 +6,7 @@ import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from typing import Any, Optional, cast +from typing import Any, cast from flask import Flask, current_app @@ -103,7 +103,7 @@ class GraphEngine: graph_runtime_state: GraphRuntimeState, max_execution_steps: int, max_execution_time: int, - thread_pool_id: Optional[str] = None, + thread_pool_id: str | None = None, ): thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_workers = 10 @@ -223,9 +223,9 @@ class GraphEngine: def _run( self, start_node_id: str, - in_parallel_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + in_parallel_id: str | None = None, + parent_parallel_id: str | None = None, + parent_parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None @@ -233,7 +233,7 @@ class GraphEngine: parallel_start_node_id = start_node_id next_node_id = start_node_id - previous_route_node_state: Optional[RouteNodeState] = None + previous_route_node_state: RouteNodeState | None = None while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: @@ -444,8 +444,8 @@ class GraphEngine: def _run_parallel_branches( self, edge_mappings: list[GraphEdge], - in_parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, + in_parallel_id: str | None = None, + parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes @@ -534,8 +534,8 @@ class GraphEngine: q: queue.Queue, parallel_id: str, parallel_start_node_id: str, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + parent_parallel_id: str | None = None, + parent_parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ): """ @@ -600,10 +600,10 @@ class GraphEngine: self, node: BaseNode, route_node_state: RouteNodeState, - parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + parallel_id: str | None = None, + parallel_start_node_id: str | None = None, + parent_parallel_id: str | None = None, + parent_parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 2e5912652c..70ec2ee2b9 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from packaging.version import Version from pydantic import ValidationError @@ -69,7 +69,7 @@ class AgentNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AgentNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -78,7 +78,7 @@ class AgentNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -401,7 +401,7 @@ class AgentNode(BaseNode): icon = None return icon - def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: # get conversation id conversation_id_variable = self.graph_runtime_state.variable_pool.get( ["sys", SystemVariableKey.CONVERSATION_ID.value] diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index d5955bdd7d..944f5f0b20 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -1,6 +1,3 @@ -from typing import Optional - - class AgentNodeError(Exception): """Base exception for all agent node errors.""" @@ -12,7 +9,7 @@ class AgentNodeError(Exception): class AgentStrategyError(AgentNodeError): """Exception raised when there's an error with the agent strategy.""" - def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None): self.strategy_name = strategy_name self.provider_name = provider_name super().__init__(message) @@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError): class AgentStrategyNotFoundError(AgentStrategyError): """Exception raised when the specified agent strategy is not found.""" - def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + def __init__(self, strategy_name: str, provider_name: str | None = None): super().__init__( f"Agent strategy '{strategy_name}' not found" + (f" for provider '{provider_name}'" if provider_name else ""), @@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError): class AgentInvocationError(AgentNodeError): """Exception raised when there's an error invoking the agent.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError): class AgentParameterError(AgentNodeError): """Exception raised when there's an error with agent parameters.""" - def __init__(self, message: str, parameter_name: Optional[str] = None): + def __init__(self, message: str, parameter_name: str | None = None): self.parameter_name = parameter_name super().__init__(message) @@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError): class AgentVariableError(AgentNodeError): """Exception raised when there's an error with variables in the agent node.""" - def __init__(self, message: str, variable_name: Optional[str] = None): + def __init__(self, message: str, variable_name: str | None = None): self.variable_name = variable_name super().__init__(message) @@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError): class ToolFileError(AgentNodeError): """Exception raised when there's an error with a tool file.""" - def __init__(self, message: str, file_id: Optional[str] = None): + def __init__(self, message: str, file_id: str | None = None): self.file_id = file_id super().__init__(message) @@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError): class AgentMessageTransformError(AgentNodeError): """Exception raised when there's an error transforming agent messages.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError): class AgentModelError(AgentNodeError): """Exception raised when there's an error with the model used by the agent.""" - def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + def __init__(self, message: str, model_name: str | None = None, provider: str | None = None): self.model_name = model_name self.provider = provider super().__init__(message) @@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError): class AgentMemoryError(AgentNodeError): """Exception raised when there's an error with the agent's memory.""" - def __init__(self, message: str, conversation_id: Optional[str] = None): + def __init__(self, message: str, conversation_id: str | None = None): self.conversation_id = conversation_id super().__init__(message) @@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError): def __init__( self, message: str, - variable_name: Optional[str] = None, - expected_type: Optional[str] = None, - actual_type: Optional[str] = None, + variable_name: str | None = None, + expected_type: str | None = None, + actual_type: str | None = None, ): self.variable_name = variable_name self.expected_type = expected_type diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 116250c5ca..184f109127 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult @@ -25,7 +25,7 @@ class AnswerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AnswerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -34,7 +34,7 @@ class AnswerNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 9e8e1787e5..00eb28b882 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -1,7 +1,6 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator -from typing import Optional from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent @@ -72,7 +71,7 @@ class StreamProcessor(ABC): for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) - def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: + def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: str | None = None) -> list[str]: if node_id not in self.rest_node_ids: self.rest_node_ids.append(node_id) node_ids = [] diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 90e45e9d25..c1dac5a1da 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,7 +1,7 @@ import json from abc import ABC from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, model_validator @@ -121,10 +121,10 @@ class RetryConfig(BaseModel): class BaseNodeData(ABC, BaseModel): title: str - desc: Optional[str] = None + desc: str | None = None version: str = "1" - error_strategy: Optional[ErrorStrategy] = None - default_value: Optional[list[DefaultValue]] = None + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None retry_config: RetryConfig = RetryConfig() @property @@ -135,7 +135,7 @@ class BaseNodeData(ABC, BaseModel): class BaseIterationNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseIterationState(BaseModel): @@ -150,7 +150,7 @@ class BaseIterationState(BaseModel): class BaseLoopNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseLoopState(BaseModel): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 3aee9b2cc2..0fe8aa5908 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -26,8 +26,8 @@ class BaseNode: graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, ): self.id = id self.tenant_id = graph_init_params.tenant_id @@ -141,7 +141,7 @@ class BaseNode: return {} @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return {} @property @@ -170,7 +170,7 @@ class BaseNode: # to BaseNodeData properties in a type-safe way @abstractmethod - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" ... @@ -185,7 +185,7 @@ class BaseNode: ... @abstractmethod - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: """Get the node description.""" ... @@ -201,7 +201,7 @@ class BaseNode: # Public interface properties that delegate to abstract methods @property - def error_strategy(self) -> Optional[ErrorStrategy]: + def error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" return self._get_error_strategy() @@ -216,7 +216,7 @@ class BaseNode: return self._get_title() @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """Get the node description.""" return self._get_description() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index d32d868651..d5cf242182 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, Optional +from typing import Any from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -31,7 +31,7 @@ class CodeNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = CodeNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -40,7 +40,7 @@ class CodeNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -50,7 +50,7 @@ class CodeNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. @@ -161,7 +161,7 @@ class CodeNode(BaseNode): def _transform_result( self, result: Mapping[str, Any], - output_schema: Optional[dict[str, CodeNodeData.Output]], + output_schema: dict[str, CodeNodeData.Output] | None, prefix: str = "", depth: int = 1, ): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 9d380c6fb6..ab23e0ae83 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: Optional[dict[str, "CodeNodeData.Output"]] = None + children: dict[str, "CodeNodeData.Output"] | None = None class Dependency(BaseModel): name: str @@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None + dependencies: list[Dependency] | None = None diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 7848bab446..b488fec84a 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any import chardet import docx @@ -50,7 +50,7 @@ class DocumentExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = DocumentExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -59,7 +59,7 @@ class DocumentExtractorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 0aff039e92..b49fdc141f 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -17,7 +17,7 @@ class EndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = EndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -26,7 +26,7 @@ class EndNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 8d7ba25d47..5a7db6e0e6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,7 +1,7 @@ import mimetypes from collections.abc import Sequence from email.message import Message -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel): class HttpRequestNodeAuthorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[HttpRequestNodeAuthorizationConfig] = None + config: HttpRequestNodeAuthorizationConfig | None = None @field_validator("config", mode="before") @classmethod @@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData): authorization: HttpRequestNodeAuthorization headers: str params: str - body: Optional[HttpRequestNodeBody] = None - timeout: Optional[HttpRequestNodeTimeout] = None - ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + body: HttpRequestNodeBody | None = None + timeout: HttpRequestNodeTimeout | None = None + ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY class Response: @@ -183,7 +183,7 @@ class Response: return f"{(self.size / 1024 / 1024):.2f} MB" @property - def parsed_content_disposition(self) -> Optional[Message]: + def parsed_content_disposition(self) -> Message | None: content_disposition = self.headers.get("content-disposition", "") if content_disposition: msg = Message() diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index bb3c453d99..837cf883c8 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from configs import dify_config from core.file import File, FileTransferMethod @@ -41,7 +41,7 @@ class HttpRequestNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = HttpRequestNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -50,7 +50,7 @@ class HttpRequestNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict[str, Any]] = None): + def get_default_config(cls, filters: dict[str, Any] | None = None): return { "type": "http-request", "config": { diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 67d6d6a886..b22bd6f508 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData): logical_operator: Literal["and", "or"] conditions: list[Condition] - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) - cases: Optional[list[Case]] = None + cases: list[Case] | None = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 82dba59cbe..857b1c6f44 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import deprecated @@ -22,7 +22,7 @@ class IfElseNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IfElseNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -31,7 +31,7 @@ class IfElseNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 7a489dd725..9608edb06e 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ - parent_loop_id: Optional[str] = None # redundant field, not used currently + parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector is_parallel: bool = False # open the parallel mode or not @@ -39,7 +39,7 @@ class IterationState(BaseIterationState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any | None = None class MetaData(BaseIterationState.MetaData): """ @@ -48,7 +48,7 @@ class IterationState(BaseIterationState): iterator_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any | None: """ Get last output. """ @@ -56,7 +56,7 @@ class IterationState(BaseIterationState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any | None: """ Get current output. """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 52eb7fdd75..2cf59bc2fb 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -6,7 +6,7 @@ from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait from datetime import datetime from queue import Empty, Queue -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from flask import Flask, current_app @@ -70,7 +70,7 @@ class IterationNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -79,7 +79,7 @@ class IterationNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -89,7 +89,7 @@ class IterationNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "type": "iteration", "config": { @@ -424,7 +424,7 @@ class IterationNode(BaseNode): graph_engine: "GraphEngine", iteration_graph: Graph, iter_run_map: dict[str, float], - parallel_mode_run_id: Optional[str] = None, + parallel_mode_run_id: str | None = None, ) -> Generator[NodeEvent | InNodeEvent, None, None]: """ run single iteration diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 8c4794cf37..1a6c9fa908 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -21,7 +21,7 @@ class IterationStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class IterationStartNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index b71271abeb..460290f0ea 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel): """ top_k: int - score_threshold: Optional[float] = None + score_threshold: float | None = None reranking_mode: str = "reranking_model" reranking_enable: bool = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class SingleRetrievalConfig(BaseModel): @@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel): Metadata Filtering Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) class KnowledgeRetrievalNodeData(BaseNodeData): @@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData): query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] - multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None - single_retrieval_config: Optional[SingleRetrievalConfig] = None - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + multiple_retrieval_config: MultipleRetrievalConfig | None = None + single_retrieval_config: SingleRetrievalConfig | None = None + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" + metadata_model_config: ModelConfig | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 3a77541807..99e1ba6d28 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import Float, and_, func, or_, select, text from sqlalchemy import cast as sqlalchemy_cast @@ -101,8 +101,8 @@ class KnowledgeRetrievalNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = KnowledgeRetrievalNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -137,7 +137,7 @@ class KnowledgeRetrievalNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -419,7 +419,7 @@ class KnowledgeRetrievalNode(BaseNode): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -576,7 +576,7 @@ class KnowledgeRetrievalNode(BaseNode): return automatic_metadata_filters def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index cf46870254..5d3b07a8f3 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, Optional, TypeAlias, TypeVar +from typing import Any, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -44,7 +44,7 @@ class ListOperatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ListOperatorNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -53,7 +53,7 @@ class ListOperatorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 222914351e..3dfb1ce28e 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -18,7 +18,7 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): enabled: bool - variable_selector: Optional[list[str]] = None + variable_selector: list[str] | None = None class VisionConfigOptions(BaseModel): @@ -51,18 +51,18 @@ class PromptConfig(BaseModel): class LLMNodeChatModelMessage(ChatModelMessage): text: str = "" - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: Optional[MemoryConfig] = None + memory: MemoryConfig | None = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: Mapping[str, Any] | None = None diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index fae127ab76..ce6bb441ab 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from sqlalchemy import select, update from sqlalchemy.orm import Session @@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc def fetch_memory( - variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance -) -> Optional[TokenBufferMemory]: + variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance +) -> TokenBufferMemory | None: if not node_data_memory: return None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index fdcdac1ec2..9ae4f275fb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -4,7 +4,7 @@ import json import logging import re from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -116,8 +116,8 @@ class LLMNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -143,7 +143,7 @@ class LLMNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LLMNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -152,7 +152,7 @@ class LLMNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -166,7 +166,7 @@ class LLMNode(BaseNode): return "1" def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: - node_inputs: Optional[dict[str, Any]] = None + node_inputs: dict[str, Any] | None = None process_data = None result_text = "" usage = LLMUsage.empty_usage() @@ -353,10 +353,10 @@ class LLMNode(BaseNode): node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, user_id: str, structured_output_enabled: bool, - structured_output: Optional[Mapping[str, Any]] = None, + structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, file_outputs: list["File"], node_id: str, @@ -708,7 +708,7 @@ class LLMNode(BaseNode): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): @@ -951,7 +951,7 @@ class LLMNode(BaseNode): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "type": "llm", "config": { @@ -979,7 +979,7 @@ class LLMNode(BaseNode): def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: Optional[str], + context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, @@ -1174,7 +1174,7 @@ class LLMNode(BaseNode): def _combine_message_content_with_role( - *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole ): match role: case PromptMessageRole.USER: @@ -1280,7 +1280,7 @@ def _handle_memory_completion_mode( def _handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, - context: Optional[str], + context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ) -> Sequence[PromptMessage]: diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3ed4d21ba5..c875b4202e 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field @@ -35,7 +35,7 @@ class LoopVariableData(BaseModel): label: str var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] - value: Optional[Any | list[str]] = None + value: Any | list[str] | None = None class LoopNodeData(BaseLoopNodeData): @@ -46,8 +46,8 @@ class LoopNodeData(BaseLoopNodeData): loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] - loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData]) - outputs: Optional[Mapping[str, Any]] = None + loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) + outputs: Mapping[str, Any] | None = None class LoopStartNodeData(BaseNodeData): @@ -72,7 +72,7 @@ class LoopState(BaseLoopState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any | None = None class MetaData(BaseLoopState.MetaData): """ @@ -81,7 +81,7 @@ class LoopState(BaseLoopState): loop_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any | None: """ Get last output. """ @@ -89,7 +89,7 @@ class LoopState(BaseLoopState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any | None: """ Get current output. """ diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 892ae88b04..e2940ae004 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -21,7 +21,7 @@ class LoopEndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopEndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class LoopEndNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 2fe3fb5567..753963dc90 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from configs import dify_config from core.variables import ( @@ -57,7 +57,7 @@ class LoopNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -66,7 +66,7 @@ class LoopNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index f5a20fc009..07e98a494f 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -21,7 +21,7 @@ class LoopStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class LoopStartNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 2739140224..2dc0aabe3c 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -50,7 +50,7 @@ class ParameterConfig(BaseModel): name: str type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: Optional[list[str]] = None + options: list[str] | None = None description: str required: bool @@ -88,8 +88,8 @@ class ParameterExtractorNodeData(BaseNodeData): model: ModelConfig query: list[str] parameters: list[ParameterConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None reasoning_mode: Literal["function_call", "prompt"] vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 1e1c10a11a..51d9a2d2e9 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -3,7 +3,7 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File @@ -98,7 +98,7 @@ class ParameterExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ParameterExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -107,7 +107,7 @@ class ParameterExtractorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -116,11 +116,11 @@ class ParameterExtractorNode(BaseNode): def get_base_node_data(self) -> BaseNodeData: return self._node_data - _model_instance: Optional[ModelInstance] = None - _model_config: Optional[ModelConfigWithCredentialsEntity] = None + _model_instance: ModelInstance | None = None + _model_config: ModelConfigWithCredentialsEntity | None = None @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "model": { "prompt_templates": { @@ -295,7 +295,7 @@ class ParameterExtractorNode(BaseNode): prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], stop: list[str], - ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=node_data_model.completion_params, @@ -330,9 +330,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. @@ -412,9 +412,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate prompt engineering prompt. @@ -450,9 +450,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate completion prompt. @@ -484,9 +484,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate chat prompt. @@ -657,7 +657,7 @@ class ParameterExtractorNode(BaseNode): return transformed_result - def _extract_complete_json_response(self, result: str) -> Optional[dict]: + def _extract_complete_json_response(self, result: str) -> dict | None: """ Extract complete json response. """ @@ -672,7 +672,7 @@ class ParameterExtractorNode(BaseNode): logger.info("extra error: %s", result) return None - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: + def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: """ Extract json from tool call. """ @@ -711,7 +711,7 @@ class ParameterExtractorNode(BaseNode): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ) -> list[ChatModelMessage]: model_mode = ModelMode(node_data.model.mode) @@ -738,7 +738,7 @@ class ParameterExtractorNode(BaseNode): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ): model_mode = ModelMode(node_data.model.mode) @@ -774,7 +774,7 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - context: Optional[str], + context: str | None, ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 6248df0edf..edde30708a 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData): query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 07fb658f24..b15193ecde 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -59,8 +59,8 @@ class QuestionClassifierNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -86,7 +86,7 @@ class QuestionClassifierNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = QuestionClassifierNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -95,7 +95,7 @@ class QuestionClassifierNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. @@ -288,7 +288,7 @@ class QuestionClassifierNode(BaseNode): node_data: QuestionClassifierNodeData, query: str, model_config: ModelConfigWithCredentialsEntity, - context: Optional[str], + context: str | None, ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) @@ -331,7 +331,7 @@ class QuestionClassifierNode(BaseNode): self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ): model_mode = ModelMode(node_data.model.mode) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 6052774e6c..5015d59ccc 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult @@ -18,7 +18,7 @@ class StartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = StartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class StartNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 5588463a36..761854045c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,6 +1,6 @@ import os from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult @@ -21,7 +21,7 @@ class TemplateTransformNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = TemplateTransformNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class TemplateTransformNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c4caf5d83b..53632f43c6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session @@ -439,7 +439,7 @@ class ToolNode(BaseNode): return result - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -448,7 +448,7 @@ class ToolNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index f4577d7573..13dbc5dbe6 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from core.variables.types import SegmentType @@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData): type: str = "variable-assigner" output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None + advanced_settings: AdvancedSettings | None = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index cc5092d0a9..1c1817496f 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult @@ -18,7 +18,7 @@ class VariableAggregatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class VariableAggregatorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 263c5a3893..8cf9e82d3b 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias from core.variables import SegmentType, Variable from core.variables.segments import BooleanSegment @@ -33,7 +33,7 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -42,7 +42,7 @@ class VariableAssignerNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -58,8 +58,8 @@ class VariableAssignerNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, ): super().__init__( diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index fdd155adf9..9915b842f7 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -61,7 +61,7 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -70,7 +70,7 @@ class VariableAssignerNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index f4668c05c5..8148934b0e 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Optional, Protocol +from typing import Literal, Protocol from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution @@ -10,7 +10,7 @@ class OrderConfig: """Configuration for ordering NodeExecution instances.""" order_by: list[str] - order_direction: Optional[Literal["asc", "desc"]] = None + order_direction: Literal["asc", "desc"] | None = None class WorkflowNodeExecutionRepository(Protocol): @@ -42,7 +42,7 @@ class WorkflowNodeExecutionRepository(Protocol): def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 4f259b64a2..0410b843b9 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, Union +from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -83,9 +83,9 @@ class WorkflowCycleManager: total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, - external_trace_id: Optional[str] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) @@ -110,9 +110,9 @@ class WorkflowCycleManager: total_steps: int, outputs: Mapping[str, Any] | None = None, exceptions_count: int = 0, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, - external_trace_id: Optional[str] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) @@ -138,10 +138,10 @@ class WorkflowCycleManager: total_steps: int, status: WorkflowExecutionStatus, error_message: str, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, exceptions_count: int = 0, - external_trace_id: Optional[str] = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) now = naive_utc_now() @@ -296,9 +296,9 @@ class WorkflowCycleManager: total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, - error_message: Optional[str] = None, + error_message: str | None = None, exceptions_count: int = 0, - finished_at: Optional[datetime] = None, + finished_at: datetime | None = None, ): """Update workflow execution with completion data.""" execution.status = status @@ -312,10 +312,10 @@ class WorkflowCycleManager: def _add_trace_task_if_needed( self, - trace_manager: Optional[TraceQueueManager], + trace_manager: TraceQueueManager | None, workflow_execution: WorkflowExecution, - conversation_id: Optional[str], - external_trace_id: Optional[str], + conversation_id: str | None, + external_trace_id: str | None, ): """Add trace task if trace manager is provided.""" if trace_manager: @@ -357,8 +357,8 @@ class WorkflowCycleManager: workflow_execution: WorkflowExecution, event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent], status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, - created_at: Optional[datetime] = None, + error: str | None = None, + created_at: datetime | None = None, ) -> WorkflowNodeExecution: """Create a node execution from an event.""" now = naive_utc_now() @@ -404,7 +404,7 @@ class WorkflowCycleManager: QueueNodeExceptionEvent, ], status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, + error: str | None = None, handle_special_values: bool = False, ): """Update node execution with completion data.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7aab4037c6..ecad75b1ca 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError @@ -47,7 +47,7 @@ class WorkflowEntry: invoke_from: InvokeFrom, call_depth: int, variable_pool: VariablePool, - thread_pool_id: Optional[str] = None, + thread_pool_id: str | None = None, ): """ Init workflow entry @@ -311,7 +311,7 @@ class WorkflowEntry: raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod - def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: # NOTE(QuantumGhost): Avoid using this function in new code. # Keep values structured as long as possible and only convert to dict # immediately before serialization (e.g., JSON serialization) to maintain diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 21fd0b9c5b..c318684b2f 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -1,7 +1,7 @@ import logging import time as time_module from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel from sqlalchemy import update @@ -33,7 +33,7 @@ def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str: @redis_fallback(default_return=None) -def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]: +def _get_last_update_timestamp(cache_key: str) -> datetime | None: """Get last update timestamp from Redis cache.""" timestamp_str = redis_client.get(cache_key) if timestamp_str: @@ -52,8 +52,8 @@ class _ProviderUpdateFilters(BaseModel): tenant_id: str provider_name: str - provider_type: Optional[str] = None - quota_type: Optional[str] = None + provider_type: str | None = None + quota_type: str | None = None class _ProviderUpdateAdditionalFilters(BaseModel): @@ -65,8 +65,8 @@ class _ProviderUpdateAdditionalFilters(BaseModel): class _ProviderUpdateValues(BaseModel): """Values to update in Provider records.""" - last_used: Optional[datetime] = None - quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression + last_used: datetime | None = None + quota_used: Any | None = None # Can be Provider.quota_used + int expression class _ProviderUpdateOperation(BaseModel): @@ -182,7 +182,7 @@ def handle(sender: Message, **kwargs): def _calculate_quota_usage( *, message: Message, system_configuration: SystemConfiguration, model_name: str -) -> Optional[int]: +) -> int | None: """Calculate quota usage based on message tokens and quota type.""" quota_unit = None for quota_configuration in system_configuration.quota_configurations: diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index fb5352ca8f..1884f7f9a9 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,6 +1,6 @@ import ssl from datetime import timedelta -from typing import Any, Optional +from typing import Any import pytz from celery import Celery, Task @@ -10,7 +10,7 @@ from configs import dify_config from dify_app import DifyApp -def _get_celery_ssl_options() -> Optional[dict[str, Any]]: +def _get_celery_ssl_options() -> dict[str, Any] | None: """Get SSL configuration for Celery broker/backend connections.""" # Use REDIS_USE_SSL for consistency with the main Redis client # Only apply SSL if we're using Redis as broker/backend diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 58ab023559..042bf8cc47 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from flask import Flask @@ -68,7 +67,7 @@ class Mail: case _: raise ValueError(f"Unsupported mail type {mail_type}") - def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): + def send(self, to: str, subject: str, html: str, from_: str | None = None): if not self._client: raise ValueError("Mail client is not initialized") diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 61b26b5b95..487917b2a7 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import redis from redis import RedisError @@ -246,7 +246,7 @@ def init_app(app: DifyApp): app.extensions["redis"] = redis_client -def redis_fallback(default_return: Optional[Any] = None): +def redis_fallback(default_return: Any | None = None): """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 7ec0889776..9053aece89 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,6 +1,5 @@ from collections.abc import Generator from datetime import timedelta -from typing import Optional from azure.identity import ChainedTokenCredential, DefaultAzureCredential from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas @@ -21,7 +20,7 @@ class AzureBlobStorage(BaseStorage): self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY - self.credential: Optional[ChainedTokenCredential] = None + self.credential: ChainedTokenCredential | None = None if self.account_key == "managedidentity": self.credential = DefaultAzureCredential() else: diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 33fa7d0a8d..2ffac9a92d 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -10,7 +10,6 @@ import tempfile from collections.abc import Generator from io import BytesIO from pathlib import Path -from typing import Optional import clickzetta # type: ignore[import] from pydantic import BaseModel, model_validator @@ -33,7 +32,7 @@ class ClickZettaVolumeConfig(BaseModel): vcluster: str = "default_ap" schema_name: str = "dify" volume_type: str = "table" # table|user|external - volume_name: Optional[str] = None # For external volumes + volume_name: str | None = None # For external volumes table_prefix: str = "dataset_" # Prefix for table volume names dify_prefix: str = "dify_km" # Directory prefix for User Volume permission_check: bool = True # Enable/disable permission checking @@ -154,7 +153,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.exception("Failed to initialize permission manager") raise - def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str: + def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str: """Get the appropriate volume path based on volume type.""" if self._config.volume_type == "user": # Add dify prefix for User Volume to organize files @@ -179,7 +178,7 @@ class ClickZettaVolumeStorage(BaseStorage): else: raise ValueError(f"Unsupported volume type: {self._config.volume_type}") - def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str: + def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str: """Get SQL prefix for volume operations.""" if self._config.volume_type == "user": return "USER VOLUME" diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 43fe771bcd..5d5d3ee295 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -10,7 +10,7 @@ import logging from dataclasses import asdict, dataclass from datetime import datetime from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -34,9 +34,9 @@ class FileMetadata: modified_at: datetime version: int | None status: FileStatus - checksum: Optional[str] = None - tags: Optional[dict[str, str]] = None - parent_version: Optional[int] = None + checksum: str | None = None + tags: dict[str, str] | None = None + parent_version: int | None = None def to_dict(self): """Convert to dictionary format""" @@ -59,7 +59,7 @@ class FileMetadata: class FileLifecycleManager: """File lifecycle manager""" - def __init__(self, storage, dataset_id: Optional[str] = None): + def __init__(self, storage, dataset_id: str | None = None): """Initialize lifecycle manager Args: @@ -74,9 +74,9 @@ class FileLifecycleManager: self._deleted_prefix = ".deleted/" # Get permission manager (if exists) - self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None) + self._permission_manager: Any | None = getattr(storage, "_permission_manager", None) - def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata: + def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata: """Save file and manage lifecycle Args: @@ -150,7 +150,7 @@ class FileLifecycleManager: logger.exception("Failed to save file with lifecycle") raise - def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: + def get_file_metadata(self, filename: str) -> FileMetadata | None: """Get file metadata Args: diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 2431b08d81..eb1116638f 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -6,7 +6,6 @@ According to ClickZetta's permission model, different Volume types have differen import logging from enum import StrEnum -from typing import Optional logger = logging.getLogger(__name__) @@ -24,7 +23,7 @@ class VolumePermission(StrEnum): class VolumePermissionManager: """Volume permission manager""" - def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): + def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None): """Initialize permission manager Args: @@ -63,7 +62,7 @@ class VolumePermissionManager: self._permission_cache: dict[str, set[str]] = {} self._current_username = None # Will get current username from connection - def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: + def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool: """Check if user has permission to perform specific operation Args: @@ -126,7 +125,7 @@ class VolumePermissionManager: logger.info("User Volume permission check failed, but permission checking is disabled in this version") return False - def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: + def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool: """Check Table Volume permission Table Volume permission rules: @@ -440,7 +439,7 @@ class VolumePermissionManager: self._permission_cache.clear() logger.debug("Permission cache cleared") - def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]: + def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]: """Get permission summary Args: @@ -582,7 +581,7 @@ class VolumePermissionManager: return any(pattern in file_path_lower for pattern in sensitive_patterns) - def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool: + def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool: """Validate operation permission Args: @@ -614,16 +613,14 @@ class VolumePermissionManager: class VolumePermissionError(Exception): """Volume permission error exception""" - def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): + def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None): self.operation = operation self.volume_type = volume_type self.dataset_id = dataset_id super().__init__(message) -def check_volume_permission( - permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None -): +def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None): """Permission check decorator function Args: diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index 5258823a07..37ff1a438e 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -8,7 +8,7 @@ eliminates the need for repetitive language switching logic. from dataclasses import dataclass from enum import StrEnum, auto -from typing import Any, Optional, Protocol +from typing import Any, Protocol from flask import render_template from pydantic import BaseModel, Field @@ -171,7 +171,7 @@ class EmailI18nService: email_type: EmailType, language_code: str, to: str, - template_context: Optional[dict[str, Any]] = None, + template_context: dict[str, Any] | None = None, ): """ Send internationalized email with branding support. @@ -515,7 +515,7 @@ def get_default_email_i18n_service() -> EmailI18nService: # Global instance -_email_i18n_service: Optional[EmailI18nService] = None +_email_i18n_service: EmailI18nService | None = None def get_email_i18n_service() -> EmailI18nService: diff --git a/api/libs/exception.py b/api/libs/exception.py index 5970269ecd..73379dfded 100644 --- a/api/libs/exception.py +++ b/api/libs/exception.py @@ -1,11 +1,9 @@ -from typing import Optional - from werkzeug.exceptions import HTTPException class BaseHTTPException(HTTPException): error_code: str = "unknown" - data: Optional[dict] = None + data: dict | None = None def __init__(self, description=None, response=None): super().__init__(description, response) diff --git a/api/libs/helper.py b/api/libs/helper.py index 09dbfac6cb..0551470f65 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -269,8 +269,8 @@ class TokenManager: cls, token_type: str, account: Optional["Account"] = None, - email: Optional[str] = None, - additional_data: Optional[dict] = None, + email: str | None = None, + additional_data: dict | None = None, ) -> str: if account is None and email is None: raise ValueError("Account or email must be provided") @@ -312,19 +312,19 @@ class TokenManager: redis_client.delete(token_key) @classmethod - def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: + def get_token_data(cls, token: str, token_type: str) -> dict[str, Any] | None: key = cls._get_token_key(token, token_type) token_data_json = redis_client.get(key) if token_data_json is None: logger.warning("%s token %s not found with key %s", token_type, token, key) return None - token_data: Optional[dict[str, Any]] = json.loads(token_data_json) + token_data: dict[str, Any] | None = json.loads(token_data_json) return token_data @classmethod - def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: + def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None: key = cls._get_account_token_key(account_id, token_type) - current_token: Optional[str] = redis_client.get(key) + current_token: str | None = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/oauth.py b/api/libs/oauth.py index df75b55019..35bd6c2c7c 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,6 +1,5 @@ import urllib.parse from dataclasses import dataclass -from typing import Optional import requests @@ -41,7 +40,7 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: Optional[str] = None): + def get_authorization_url(self, invite_token: str | None = None): params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, @@ -93,7 +92,7 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: Optional[str] = None): + def get_authorization_url(self, invite_token: str | None = None): params = { "client_id": self.client_id, "response_type": "code", diff --git a/api/libs/orjson.py b/api/libs/orjson.py index 2fc5ce8dd3..6e7c6b738d 100644 --- a/api/libs/orjson.py +++ b/api/libs/orjson.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import orjson @@ -6,6 +6,6 @@ import orjson def orjson_dumps( obj: Any, encoding: str = "utf-8", - option: Optional[int] = None, + option: int | None = None, ) -> str: return orjson.dumps(obj, option=option).decode(encoding) diff --git a/api/models/account.py b/api/models/account.py index aefef673df..8c1f990aa2 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -90,24 +90,24 @@ class Account(UserMixin, Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255)) - password: Mapped[Optional[str]] = mapped_column(String(255)) - password_salt: Mapped[Optional[str]] = mapped_column(String(255)) - avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - interface_language: Mapped[Optional[str]] = mapped_column(String(255)) - interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - timezone: Mapped[Optional[str]] = mapped_column(String(255)) - last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + password: Mapped[str | None] = mapped_column(String(255)) + password_salt: Mapped[str | None] = mapped_column(String(255)) + avatar: Mapped[str | None] = mapped_column(String(255), nullable=True) + interface_language: Mapped[str | None] = mapped_column(String(255)) + interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True) + timezone: Mapped[str | None] = mapped_column(String(255)) + last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True) last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) - initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): - self.role: Optional[TenantAccountRole] = None - self._current_tenant: Optional[Tenant] = None + self.role: TenantAccountRole | None = None + self._current_tenant: Tenant | None = None @property def is_password_set(self): @@ -232,10 +232,10 @@ class Tenant(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255)) - encrypt_public_key: Mapped[Optional[str]] = mapped_column(sa.Text) + encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text) plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) - custom_config: Mapped[Optional[str]] = mapped_column(sa.Text) + custom_config: Mapped[str | None] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) @@ -271,7 +271,7 @@ class TenantAccountJoin(Base): account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) role: Mapped[str] = mapped_column(String(16), server_default="normal") - invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) + invited_by: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) @@ -305,10 +305,10 @@ class InvitationCode(Base): batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) - used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) - used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + used_at: Mapped[datetime | None] = mapped_column(DateTime) + used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID) + used_by_account_id: Mapped[str | None] = mapped_column(StringUUID) + deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/dataset.py b/api/models/dataset.py index 7314945053..662cfeb0d2 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -10,7 +10,7 @@ import re import time from datetime import datetime from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select @@ -56,7 +56,7 @@ class Dataset(Base): provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying")) permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying")) data_source_type = mapped_column(String(255)) - indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) + indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -330,42 +330,42 @@ class Document(Base): created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + processing_started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # parsing file_id = mapped_column(sa.Text, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable - parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable + parsing_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # cleaning - cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + cleaning_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # split - splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + splitting_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # indexing - tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + indexing_latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # pause - is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + is_paused: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) paused_by = mapped_column(StringUUID, nullable=True) - paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + paused_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # error error = mapped_column(sa.Text, nullable=True) - stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # basic fields indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying")) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) - archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) @@ -677,17 +677,17 @@ class DocumentSegment(Base): # basic fields hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(sa.Text, nullable=True) - stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) @property def dataset(self): @@ -829,8 +829,8 @@ class ChildChunk(Base): updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) - indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(sa.Text, nullable=True) @property diff --git a/api/models/model.py b/api/models/model.py index 50c07268dd..928508cc48 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -76,9 +76,9 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji icon = mapped_column(String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(String(255)) + icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) @@ -90,7 +90,7 @@ class App(Base): is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) tracing = mapped_column(sa.Text, nullable=True) - max_active_requests: Mapped[Optional[int]] + max_active_requests: Mapped[int | None] created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -134,7 +134,7 @@ class App(Base): return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -291,7 +291,7 @@ class App(Base): return tags or [] @property - def author_name(self) -> Optional[str]: + def author_name(self) -> str | None: if self.created_by: account = db.session.query(Account).where(Account.id == self.created_by).first() if account: @@ -334,7 +334,7 @@ class AppModelConfig(Base): file_upload = mapped_column(sa.Text) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -546,7 +546,7 @@ class RecommendedApp(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -570,12 +570,12 @@ class InstalledApp(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -711,7 +711,7 @@ class Conversation(Base): @property def model_config(self): model_config = {} - app_model_config: Optional[AppModelConfig] = None + app_model_config: AppModelConfig | None = None if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: @@ -845,7 +845,7 @@ class Conversation(Base): ) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: return db.session.query(App).where(App.id == self.app_id).first() @property @@ -858,7 +858,7 @@ class Conversation(Base): return None @property - def from_account_name(self) -> Optional[str]: + def from_account_name(self) -> str | None: if self.from_account_id: account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: @@ -933,14 +933,14 @@ class Message(Base): status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) error = mapped_column(sa.Text) message_metadata = mapped_column(sa.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) from_source: Mapped[str] = mapped_column(String(255), nullable=False) - from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) - from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) + from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) + from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) @property def inputs(self) -> dict[str, Any]: @@ -1337,9 +1337,9 @@ class MessageFile(Base): message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) - url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + url: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True) + upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1356,8 +1356,8 @@ class MessageAnnotation(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) - message_id: Mapped[Optional[str]] = mapped_column(StringUUID) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) + message_id: Mapped[str | None] = mapped_column(StringUUID) question = mapped_column(sa.Text, nullable=True) content = mapped_column(sa.Text, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) @@ -1729,18 +1729,18 @@ class MessageAgentThought(Base): # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design tool_process_data = mapped_column(sa.Text, nullable=True) message = mapped_column(sa.Text, nullable=True) - message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) message_unit_price = mapped_column(sa.Numeric, nullable=True) message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) message_files = mapped_column(sa.Text, nullable=True) answer = mapped_column(sa.Text, nullable=True) - answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) answer_unit_price = mapped_column(sa.Numeric, nullable=True) answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) total_price = mapped_column(sa.Numeric, nullable=True) currency = mapped_column(String, nullable=True) - latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1838,11 +1838,11 @@ class DatasetRetrieverResource(Base): document_name = mapped_column(sa.Text, nullable=False) data_source_type = mapped_column(sa.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + score: Mapped[float | None] = mapped_column(sa.Float, nullable=True) content = mapped_column(sa.Text, nullable=False) - hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) index_node_hash = mapped_column(sa.Text, nullable=True) retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/provider.py b/api/models/provider.py index 17094f6d6e..aacc6e505a 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,6 @@ from datetime import datetime from enum import StrEnum, auto from functools import cached_property -from typing import Optional import sqlalchemy as sa from sqlalchemy import DateTime, String, func, text @@ -63,14 +62,14 @@ class Provider(Base): String(40), nullable=False, server_default=text("'custom'::character varying") ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - quota_type: Mapped[Optional[str]] = mapped_column( + quota_type: Mapped[str | None] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") ) - quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True) - quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0) + quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -133,7 +132,7 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -201,17 +200,17 @@ class ProviderOrder(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) + payment_id: Mapped[str | None] = mapped_column(String(191)) + transaction_id: Mapped[str | None] = mapped_column(String(191)) quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(String(40)) - total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer) + currency: Mapped[str | None] = mapped_column(String(40)) + total_amount: Mapped[int | None] = mapped_column(sa.Integer) payment_status: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + paid_at: Mapped[datetime | None] = mapped_column(DateTime) + pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) + refunded_at: Mapped[datetime | None] = mapped_column(DateTime) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -255,9 +254,9 @@ class LoadBalancingModelConfig(Base): model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) + encrypted_config: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 8456d65a87..5b4c486bc4 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,6 +1,5 @@ import json from datetime import datetime -from typing import Optional import sqlalchemy as sa from sqlalchemy import DateTime, String, func @@ -27,7 +26,7 @@ class DataSourceOauthBinding(Base): source_info = mapped_column(JSONB, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -45,7 +44,7 @@ class DataSourceApiKeyAuthBinding(Base): credentials = mapped_column(sa.Text, nullable=True) # JSON created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 9a52fcfb41..3da1674536 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional import sqlalchemy as sa from celery import states @@ -32,7 +31,7 @@ class CeleryTask(Base): args = mapped_column(sa.LargeBinary, nullable=True) kwargs = mapped_column(sa.LargeBinary, nullable=True) worker = mapped_column(String(155), nullable=True) - retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) queue = mapped_column(String(155), nullable=True) @@ -46,4 +45,4 @@ class CeleryTaskSet(Base): ) taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 96ad76eae5..4f50e9e619 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import urlparse import sqlalchemy as sa @@ -487,13 +487,13 @@ class ToolFile(TypeBase): # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID) # conversation id - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) # file key file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None) + original_url: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None) # name name: Mapped[str] = mapped_column(default="") # size diff --git a/api/models/workflow.py b/api/models/workflow.py index 78582dd3fb..9d129a09e2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -130,7 +130,7 @@ class Workflow(Base): _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) + updated_by: Mapped[str | None] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, @@ -499,18 +499,18 @@ class WorkflowRun(Base): type: Mapped[str] = mapped_column(String(255)) triggered_from: Mapped[str] = mapped_column(String(255)) version: Mapped[str] = mapped_column(String(255)) - graph: Mapped[Optional[str]] = mapped_column(sa.Text) - inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + graph: Mapped[str | None] = mapped_column(sa.Text) + inputs: Mapped[str | None] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded - outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[Optional[str]] = mapped_column(sa.Text) + outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}") + error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) @property @@ -706,24 +706,24 @@ class WorkflowNodeExecutionModel(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) triggered_from: Mapped[str] = mapped_column(String(255)) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) + node_execution_id: Mapped[str | None] = mapped_column(String(255)) node_id: Mapped[str] = mapped_column(String(255)) node_type: Mapped[str] = mapped_column(String(255)) title: Mapped[str] = mapped_column(String(255)) - inputs: Mapped[Optional[str]] = mapped_column(sa.Text) - process_data: Mapped[Optional[str]] = mapped_column(sa.Text) - outputs: Mapped[Optional[str]] = mapped_column(sa.Text) + inputs: Mapped[str | None] = mapped_column(sa.Text) + process_data: Mapped[str | None] = mapped_column(sa.Text) + outputs: Mapped[str | None] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) - error: Mapped[Optional[str]] = mapped_column(sa.Text) + error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) - execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text) + execution_metadata: Mapped[str | None] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) @property def created_by_account(self): diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 00a2d1f87d..fa2c94b623 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -11,7 +11,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m from collections.abc import Sequence from datetime import datetime -from typing import Optional, Protocol +from typing import Protocol from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel @@ -44,7 +44,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr app_id: str, workflow_id: str, node_id: str, - ) -> Optional[WorkflowNodeExecutionModel]: + ) -> WorkflowNodeExecutionModel | None: """ Get the most recent execution for a specific node. @@ -87,8 +87,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr def get_execution_by_id( self, execution_id: str, - tenant_id: Optional[str] = None, - ) -> Optional[WorkflowNodeExecutionModel]: + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: """ Get a workflow node execution by its ID. diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 59e7baeb79..3ac28fad75 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -36,7 +36,7 @@ Example: from collections.abc import Sequence from datetime import datetime -from typing import Optional, Protocol +from typing import Protocol from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -58,7 +58,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): app_id: str, triggered_from: str, limit: int = 20, - last_id: Optional[str] = None, + last_id: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -90,7 +90,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): tenant_id: str, app_id: str, run_id: str, - ) -> Optional[WorkflowRun]: + ) -> WorkflowRun | None: """ Get a specific workflow run by ID. diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index e6a23ddf9f..cbb09af542 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -7,7 +7,6 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. from collections.abc import Sequence from datetime import datetime -from typing import Optional from sqlalchemy import delete, desc, select from sqlalchemy.orm import Session, sessionmaker @@ -49,7 +48,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut app_id: str, workflow_id: str, node_id: str, - ) -> Optional[WorkflowNodeExecutionModel]: + ) -> WorkflowNodeExecutionModel | None: """ Get the most recent execution for a specific node. @@ -116,8 +115,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut def get_execution_by_id( self, execution_id: str, - tenant_id: Optional[str] = None, - ) -> Optional[WorkflowNodeExecutionModel]: + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: """ Get a workflow node execution by its ID. diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 6294846f5e..205f8c87ee 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -22,7 +22,6 @@ Implementation Notes: import logging from collections.abc import Sequence from datetime import datetime -from typing import Optional from sqlalchemy import delete, select from sqlalchemy.orm import Session, sessionmaker @@ -61,7 +60,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): app_id: str, triggered_from: str, limit: int = 20, - last_id: Optional[str] = None, + last_id: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -107,7 +106,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): tenant_id: str, app_id: str, run_id: str, - ) -> Optional[WorkflowRun]: + ) -> WorkflowRun | None: """ Get a specific workflow run by ID with tenant and app isolation. """ diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 2b1e6b47cc..9efd46ba5d 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -1,6 +1,6 @@ import datetime import time -from typing import Optional, TypedDict +from typing import TypedDict import click from sqlalchemy import func, select @@ -17,7 +17,7 @@ from services.feature_service import FeatureService class CleanupConfig(TypedDict): clean_day: datetime.datetime - plan_filter: Optional[str] + plan_filter: str | None add_logs: bool diff --git a/api/services/account_service.py b/api/services/account_service.py index b10887e88e..8362e415c1 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -5,7 +5,7 @@ import secrets import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel from sqlalchemy import func @@ -171,7 +171,7 @@ class AccountService: return token @staticmethod - def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: + def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: """authenticate account with email and password""" account = db.session.query(Account).filter_by(email=email).first() @@ -228,9 +228,9 @@ class AccountService: email: str, name: str, interface_language: str, - password: Optional[str] = None, + password: str | None = None, interface_theme: str = "light", - is_setup: Optional[bool] = False, + is_setup: bool | None = False, ) -> Account: """create account""" if not FeatureService.get_system_features().is_allow_register and not is_setup: @@ -276,7 +276,7 @@ class AccountService: @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: Optional[str] = None + email: str, name: str, interface_language: str, password: str | None = None ) -> Account: """create account""" account = AccountService.create_account( @@ -330,7 +330,7 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = ( + account_integrate: AccountIntegrate | None = ( db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() ) @@ -391,7 +391,7 @@ class AccountService: db.session.commit() @staticmethod - def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: + def login(account: Account, *, ip_address: str | None = None) -> TokenPair: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) @@ -439,8 +439,8 @@ class AccountService: @classmethod def send_reset_password_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", is_allow_register: bool = False, ): @@ -473,8 +473,8 @@ class AccountService: @classmethod def send_email_register_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", ): account_email = account.email if account else email @@ -507,11 +507,11 @@ class AccountService: @classmethod def send_change_email_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, - old_email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, + old_email: str | None = None, language: str = "en-US", - phase: Optional[str] = None, + phase: str | None = None, ): account_email = account.email if account else email if account_email is None: @@ -538,8 +538,8 @@ class AccountService: @classmethod def send_change_email_completed_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", ): account_email = account.email if account else email @@ -554,10 +554,10 @@ class AccountService: @classmethod def send_owner_transfer_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", ): account_email = account.email if account else email if account_email is None: @@ -583,10 +583,10 @@ class AccountService: @classmethod def send_old_owner_transfer_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", new_owner_email: str = "", ): account_email = account.email if account else email @@ -604,10 +604,10 @@ class AccountService: @classmethod def send_new_owner_transfer_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", ): account_email = account.email if account else email if account_email is None: @@ -624,8 +624,8 @@ class AccountService: def generate_reset_password_token( cls, email: str, - account: Optional[Account] = None, - code: Optional[str] = None, + account: Account | None = None, + code: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -640,7 +640,7 @@ class AccountService: def generate_email_register_token( cls, email: str, - code: Optional[str] = None, + code: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -653,9 +653,9 @@ class AccountService: def generate_change_email_token( cls, email: str, - account: Optional[Account] = None, - code: Optional[str] = None, - old_email: Optional[str] = None, + account: Account | None = None, + code: str | None = None, + old_email: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -671,8 +671,8 @@ class AccountService: def generate_owner_transfer_token( cls, email: str, - account: Optional[Account] = None, - code: Optional[str] = None, + account: Account | None = None, + code: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -700,26 +700,26 @@ class AccountService: TokenManager.revoke_token(token, "owner_transfer") @classmethod - def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_reset_password_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "reset_password") @classmethod - def get_email_register_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_email_register_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_register") @classmethod - def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_change_email_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "change_email") @classmethod - def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "owner_transfer") @classmethod def send_email_code_login_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", ): email = account.email if account else email @@ -743,7 +743,7 @@ class AccountService: return token @classmethod - def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @classmethod @@ -965,7 +965,7 @@ class AccountService: class TenantService: @staticmethod - def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: + def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant: """Create tenant""" if ( not FeatureService.get_system_features().is_allow_create_workspace @@ -996,9 +996,7 @@ class TenantService: return tenant @staticmethod - def create_owner_tenant_if_not_exist( - account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False - ): + def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): """Check if user have a workspace or not""" available_ta = ( db.session.query(TenantAccountJoin) @@ -1070,7 +1068,7 @@ class TenantService: return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: Optional[str] = None): + def switch_tenant(account: Account, tenant_id: str | None = None): """Switch the current workspace for the account""" # Ensure tenant_id is provided @@ -1152,7 +1150,7 @@ class TenantService: ) @staticmethod - def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]: + def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: """Get the role of the current account for a given tenant""" join = ( db.session.query(TenantAccountJoin) @@ -1292,13 +1290,13 @@ class RegisterService: cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None, - is_setup: Optional[bool] = False, - create_workspace_required: Optional[bool] = True, + password: str | None = None, + open_id: str | None = None, + provider: str | None = None, + language: str | None = None, + status: AccountStatus | None = None, + is_setup: bool | None = False, + create_workspace_required: bool | None = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -1415,9 +1413,7 @@ class RegisterService: redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid( - cls, workspace_id: Optional[str], email: str, token: str - ) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -1456,8 +1452,8 @@ class RegisterService: @classmethod def get_invitation_by_token( - cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None - ) -> Optional[dict[str, str]]: + cls, token: str, workspace_id: str | None = None, email: str | None = None + ) -> dict[str, str] | None: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 8578f38a0d..d631ce812f 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,5 @@ import threading -from typing import Any, Optional +from typing import Any import pytz @@ -35,7 +35,7 @@ class AgentService: if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Optional[Message] = ( + message: Message | None = ( db.session.query(Message) .where( Message.id == message_id, diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 34681ba111..9feca7337f 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional import pandas as pd from sqlalchemy import or_, select @@ -42,7 +41,7 @@ class AppAnnotationService: if not message: raise NotFound("Message Not Exists.") - annotation: Optional[MessageAnnotation] = message.annotation + annotation: MessageAnnotation | None = message.annotation # save the message annotation if annotation: annotation.content = args["answer"] diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 49ff28d191..1c4a9b96ec 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,7 +4,6 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum -from typing import Optional from urllib.parse import urlparse from uuid import uuid4 @@ -61,8 +60,8 @@ class ImportStatus(StrEnum): class Import(BaseModel): id: str status: ImportStatus - app_id: Optional[str] = None - app_mode: Optional[str] = None + app_id: str | None = None + app_mode: str | None = None current_dsl_version: str = CURRENT_DSL_VERSION imported_dsl_version: str = "" error: str = "" @@ -121,14 +120,14 @@ class AppDslService: *, account: Account, import_mode: str, - yaml_content: Optional[str] = None, - yaml_url: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - icon_type: Optional[str] = None, - icon: Optional[str] = None, - icon_background: Optional[str] = None, - app_id: Optional[str] = None, + yaml_content: str | None = None, + yaml_url: str | None = None, + name: str | None = None, + description: str | None = None, + icon_type: str | None = None, + icon: str | None = None, + icon_background: str | None = None, + app_id: str | None = None, ) -> Import: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) @@ -407,15 +406,15 @@ class AppDslService: def _create_or_update_app( self, *, - app: Optional[App], + app: App | None, data: dict, account: Account, - name: Optional[str] = None, - description: Optional[str] = None, - icon_type: Optional[str] = None, - icon: Optional[str] = None, - icon_background: Optional[str] = None, - dependencies: Optional[list[PluginDependency]] = None, + name: str | None = None, + description: str | None = None, + icon_type: str | None = None, + icon: str | None = None, + icon_background: str | None = None, + dependencies: list[PluginDependency] | None = None, ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) @@ -533,7 +532,7 @@ class AppDslService: return app @classmethod - def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: Optional[str] = None) -> str: + def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str: """ Export app :param app_model: App instance @@ -566,7 +565,7 @@ class AppDslService: @classmethod def _append_workflow_export_data( - cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None + cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None ): """ Append workflow export data diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index c1ef206c99..1fae452d38 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Optional, Union +from typing import Any, Union from openai._exceptions import RateLimitError @@ -214,7 +214,7 @@ class AppGenerateService: ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow: """ Get workflow :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index c553ec8c19..ab2b38ec01 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, TypedDict, cast +from typing import TypedDict, cast from flask_sqlalchemy.pagination import Pagination @@ -370,7 +370,7 @@ class AppService: } ) else: - app_model_config: Optional[AppModelConfig] = app_model.app_model_config + app_model_config: AppModelConfig | None = app_model.app_model_config if not app_model_config: return meta @@ -393,7 +393,7 @@ class AppService: meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: Optional[ApiToolProvider] = ( + provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first() ) if provider is None: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a7cd5e9487..1158fc5197 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -2,7 +2,6 @@ import io import logging import uuid from collections.abc import Generator -from typing import Optional from flask import Response, stream_with_context from werkzeug.datastructures import FileStorage @@ -30,7 +29,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod - def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): + def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None): if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: @@ -77,15 +76,15 @@ class AudioService: def transcript_tts( cls, app_model: App, - text: Optional[str] = None, - voice: Optional[str] = None, - end_user: Optional[str] = None, - message_id: Optional[str] = None, + text: str | None = None, + voice: str | None = None, + end_user: str | None = None, + message_id: str | None = None, is_draft: bool = False, ): from app import app - def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False): + def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False): with app.app_context(): if voice is None: if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 066bed3234..a364862a88 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,5 +1,5 @@ import os -from typing import Literal, Optional +from typing import Literal import httpx from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed @@ -73,7 +73,7 @@ class BillingService: def is_tenant_owner_or_admin(current_user: Account): tenant_id = current_user.current_tenant_id - join: Optional[TenantAccountJoin] = ( + join: TenantAccountJoin | None = ( db.session.query(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index a29204aabf..a8e51a426d 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,7 +1,7 @@ import contextlib import logging from collections.abc import Callable, Sequence -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -36,12 +36,12 @@ class ConversationService: *, session: Session, app_model: App, - user: Optional[Union[Account, EndUser]], - last_id: Optional[str], + user: Union[Account, EndUser] | None, + last_id: str | None, limit: int, invoke_from: InvokeFrom, - include_ids: Optional[Sequence[str]] = None, - exclude_ids: Optional[Sequence[str]] = None, + include_ids: Sequence[str] | None = None, + exclude_ids: Sequence[str] | None = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: @@ -118,7 +118,7 @@ class ConversationService: cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, name: str, auto_generate: bool, ): @@ -158,7 +158,7 @@ class ConversationService: return conversation @classmethod - def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): conversation = ( db.session.query(Conversation) .where( @@ -178,7 +178,7 @@ class ConversationService: return conversation @classmethod - def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): try: logger.info( "Initiating conversation deletion for app_name %s, conversation_id: %s", @@ -200,9 +200,9 @@ class ConversationService: cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, limit: int, - last_id: Optional[str], + last_id: str | None, ) -> InfiniteScrollPagination: conversation = cls.get_conversation(app_model, conversation_id, user) @@ -248,7 +248,7 @@ class ConversationService: app_model: App, conversation_id: str, variable_id: str, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, new_value: Any, ): """ diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8c8b368d42..102629629d 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal import sqlalchemy as sa from sqlalchemy import exists, func, select @@ -185,16 +185,16 @@ class DatasetService: def create_empty_dataset( tenant_id: str, name: str, - description: Optional[str], - indexing_technique: Optional[str], + description: str | None, + indexing_technique: str | None, account: Account, - permission: Optional[str] = None, + permission: str | None = None, provider: str = "vendor", - external_knowledge_api_id: Optional[str] = None, - external_knowledge_id: Optional[str] = None, - embedding_model_provider: Optional[str] = None, - embedding_model_name: Optional[str] = None, - retrieval_model: Optional[RetrievalModel] = None, + external_knowledge_api_id: str | None = None, + external_knowledge_id: str | None = None, + embedding_model_provider: str | None = None, + embedding_model_name: str | None = None, + retrieval_model: RetrievalModel | None = None, ): # check if dataset name already exists if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): @@ -257,8 +257,8 @@ class DatasetService: return dataset @staticmethod - def get_dataset(dataset_id) -> Optional[Dataset]: - dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Dataset | None: + dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset @staticmethod @@ -694,7 +694,7 @@ class DatasetService: raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None): if not dataset: raise ValueError("Dataset not found") @@ -868,7 +868,7 @@ class DocumentService: } @staticmethod - def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + def get_document(dataset_id: str, document_id: str | None = None) -> Document | None: if document_id: document = ( db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() @@ -878,7 +878,7 @@ class DocumentService: return None @staticmethod - def get_document_by_id(document_id: str) -> Optional[Document]: + def get_document_by_id(document_id: str) -> Document | None: document = db.session.query(Document).where(Document.id == document_id).first() return document @@ -1099,7 +1099,7 @@ class DocumentService: dataset: Dataset, knowledge_config: KnowledgeConfig, account: Account | Any, - dataset_process_rule: Optional[DatasetProcessRule] = None, + dataset_process_rule: DatasetProcessRule | None = None, created_from: str = "web", ) -> tuple[list[Document], str]: # check doc_form @@ -1463,7 +1463,7 @@ class DocumentService: dataset: Dataset, document_data: KnowledgeConfig, account: Account, - dataset_process_rule: Optional[DatasetProcessRule] = None, + dataset_process_rule: DatasetProcessRule | None = None, created_from: str = "web", ): assert isinstance(current_user, Account) @@ -2655,7 +2655,7 @@ class SegmentService: @classmethod def get_child_chunks( - cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None ): assert isinstance(current_user, Account) @@ -2674,7 +2674,7 @@ class SegmentService: return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod - def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: + def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None: """Get a child chunk by its ID.""" result = ( db.session.query(ChildChunk) @@ -2711,7 +2711,7 @@ class SegmentService: return paginated_segments.items, paginated_segments.total @classmethod - def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: + def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: """Get a segment by its ID.""" result = ( db.session.query(DocumentSegment) diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index 4545f385eb..c9fb1c9e21 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal, Union from pydantic import BaseModel @@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel): class Authorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[AuthorizationConfig] = None + config: AuthorizationConfig | None = None class ProcessStatusSetting(BaseModel): @@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel): class ExternalKnowledgeApiSetting(BaseModel): url: str request_method: str - headers: Optional[dict] = None - params: Optional[dict] = None + headers: dict | None = None + params: dict | None = None diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 344c67885e..94ce9d5415 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -11,14 +11,14 @@ class ParentMode(StrEnum): class NotionIcon(BaseModel): type: str - url: Optional[str] = None - emoji: Optional[str] = None + url: str | None = None + emoji: str | None = None class NotionPage(BaseModel): page_id: str page_name: str - page_icon: Optional[NotionIcon] = None + page_icon: NotionIcon | None = None type: str @@ -40,9 +40,9 @@ class FileInfo(BaseModel): class InfoList(BaseModel): data_source_type: Literal["upload_file", "notion_import", "website_crawl"] - notion_info_list: Optional[list[NotionInfo]] = None - file_info_list: Optional[FileInfo] = None - website_info_list: Optional[WebsiteInfo] = None + notion_info_list: list[NotionInfo] | None = None + file_info_list: FileInfo | None = None + website_info_list: WebsiteInfo | None = None class DataSource(BaseModel): @@ -61,20 +61,20 @@ class Segmentation(BaseModel): class Rule(BaseModel): - pre_processing_rules: Optional[list[PreProcessingRule]] = None - segmentation: Optional[Segmentation] = None - parent_mode: Optional[Literal["full-doc", "paragraph"]] = None - subchunk_segmentation: Optional[Segmentation] = None + pre_processing_rules: list[PreProcessingRule] | None = None + segmentation: Segmentation | None = None + parent_mode: Literal["full-doc", "paragraph"] | None = None + subchunk_segmentation: Segmentation | None = None class ProcessRule(BaseModel): mode: Literal["automatic", "custom", "hierarchical"] - rules: Optional[Rule] = None + rules: Rule | None = None class RerankingModel(BaseModel): - reranking_provider_name: Optional[str] = None - reranking_model_name: Optional[str] = None + reranking_provider_name: str | None = None + reranking_model_name: str | None = None class WeightVectorSetting(BaseModel): @@ -88,20 +88,20 @@ class WeightKeywordSetting(BaseModel): class WeightModel(BaseModel): - weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None - vector_setting: Optional[WeightVectorSetting] = None - keyword_setting: Optional[WeightKeywordSetting] = None + weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None + vector_setting: WeightVectorSetting | None = None + keyword_setting: WeightKeywordSetting | None = None class RetrievalModel(BaseModel): search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"] reranking_enable: bool - reranking_model: Optional[RerankingModel] = None - reranking_mode: Optional[str] = None + reranking_model: RerankingModel | None = None + reranking_mode: str | None = None top_k: int score_threshold_enabled: bool - score_threshold: Optional[float] = None - weights: Optional[WeightModel] = None + score_threshold: float | None = None + weights: WeightModel | None = None class MetaDataConfig(BaseModel): @@ -110,29 +110,29 @@ class MetaDataConfig(BaseModel): class KnowledgeConfig(BaseModel): - original_document_id: Optional[str] = None + original_document_id: str | None = None duplicate: bool = True indexing_technique: Literal["high_quality", "economy"] - data_source: Optional[DataSource] = None - process_rule: Optional[ProcessRule] = None - retrieval_model: Optional[RetrievalModel] = None + data_source: DataSource | None = None + process_rule: ProcessRule | None = None + retrieval_model: RetrievalModel | None = None doc_form: str = "text_model" doc_language: str = "English" - embedding_model: Optional[str] = None - embedding_model_provider: Optional[str] = None - name: Optional[str] = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + name: str | None = None class SegmentUpdateArgs(BaseModel): - content: Optional[str] = None - answer: Optional[str] = None - keywords: Optional[list[str]] = None + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None regenerate_child_chunks: bool = False - enabled: Optional[bool] = None + enabled: bool | None = None class ChildChunkUpdateArgs(BaseModel): - id: Optional[str] = None + id: str | None = None content: str @@ -143,13 +143,13 @@ class MetadataArgs(BaseModel): class MetadataUpdateArgs(BaseModel): name: str - value: Optional[str | int | float] = None + value: str | int | float | None = None class MetadataDetail(BaseModel): id: str name: str - value: Optional[str | int | float] = None + value: str | int | float | None = None class DocumentMetadataOperation(BaseModel): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 647052d739..49d48f044c 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -42,11 +41,11 @@ class CustomConfigurationResponse(BaseModel): """ status: CustomConfigurationStatus - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None - available_credentials: Optional[list[CredentialConfiguration]] = None - custom_models: Optional[list[CustomModelConfiguration]] = None - can_added_models: Optional[list[UnaddedModelConfiguration]] = None + current_credential_id: str | None = None + current_credential_name: str | None = None + available_credentials: list[CredentialConfiguration] | None = None + custom_models: list[CustomModelConfiguration] | None = None + can_added_models: list[UnaddedModelConfiguration] | None = None class SystemConfigurationResponse(BaseModel): @@ -55,7 +54,7 @@ class SystemConfigurationResponse(BaseModel): """ enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None + current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] @@ -67,15 +66,15 @@ class ProviderResponse(BaseModel): tenant_id: str provider: str label: I18nObject - description: Optional[I18nObject] = None - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None - background: Optional[str] = None - help: Optional[ProviderHelpEntity] = None + description: I18nObject | None = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None + background: str | None = None + help: ProviderHelpEntity | None = None supported_model_types: list[ModelType] configurate_methods: list[ConfigurateMethod] - provider_credential_schema: Optional[ProviderCredentialSchema] = None - model_credential_schema: Optional[ModelCredentialSchema] = None + provider_credential_schema: ProviderCredentialSchema | None = None + model_credential_schema: ModelCredentialSchema | None = None preferred_provider_type: ProviderType custom_configuration: CustomConfigurationResponse system_configuration: SystemConfigurationResponse @@ -108,8 +107,8 @@ class ProviderWithModelsResponse(BaseModel): tenant_id: str provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] diff --git a/api/services/errors/base.py b/api/services/errors/base.py index 35ea28468e..0f9631190f 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,6 +1,3 @@ -from typing import Optional - - class BaseServiceError(ValueError): - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py index ca4c9a611d..5bf34f3aa6 100644 --- a/api/services/errors/llm.py +++ b/api/services/errors/llm.py @@ -1,12 +1,9 @@ -from typing import Optional - - class InvokeError(Exception): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description def __str__(self): diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 3911b763b6..b6ba3bafea 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from urllib.parse import urlparse import httpx @@ -100,7 +100,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + external_knowledge_api: ExternalKnowledgeApis | None = ( db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first() ) if external_knowledge_api is None: @@ -109,7 +109,7 @@ class ExternalDatasetService: @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + external_knowledge_api: ExternalKnowledgeApis | None = ( db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() ) if external_knowledge_api is None: @@ -151,7 +151,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ( + external_knowledge_binding: ExternalKnowledgeBindings | None = ( db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() ) if not external_knowledge_binding: @@ -203,7 +203,7 @@ class ExternalDatasetService: return response @staticmethod - def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]: + def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]: authorization = deepcopy(authorization) if headers: headers = deepcopy(headers) @@ -277,7 +277,7 @@ class ExternalDatasetService: dataset_id: str, query: str, external_retrieval_parameters: dict, - metadata_condition: Optional[MetadataCondition] = None, + metadata_condition: MetadataCondition | None = None, ): external_knowledge_binding = ( db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() diff --git a/api/services/message_service.py b/api/services/message_service.py index 7c1f74e488..e2e27443ba 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Union +from typing import Union from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -29,9 +29,9 @@ class MessageService: def pagination_by_first_id( cls, app_model: App, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, conversation_id: str, - first_id: Optional[str], + first_id: str | None, limit: int, order: str = "asc", ) -> InfiniteScrollPagination: @@ -91,11 +91,11 @@ class MessageService: def pagination_by_last_id( cls, app_model: App, - user: Optional[Union[Account, EndUser]], - last_id: Optional[str], + user: Union[Account, EndUser] | None, + last_id: str | None, limit: int, - conversation_id: Optional[str] = None, - include_ids: Optional[list] = None, + conversation_id: str | None = None, + include_ids: list | None = None, ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -145,9 +145,9 @@ class MessageService: *, app_model: App, message_id: str, - user: Optional[Union[Account, EndUser]], - rating: Optional[str], - content: Optional[str], + user: Union[Account, EndUser] | None, + rating: str | None, + content: str | None, ): if not user: raise ValueError("user cannot be None") @@ -196,7 +196,7 @@ class MessageService: return [record.to_dict() for record in feedbacks] @classmethod - def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): message = ( db.session.query(Message) .where( @@ -216,7 +216,7 @@ class MessageService: @classmethod def get_suggested_questions_after_answer( - cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom ) -> list[Message]: if not user: raise ValueError("user cannot be None") diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 208ecdb79e..6add830813 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -1,6 +1,5 @@ import copy import logging -from typing import Optional from flask_login import current_user @@ -237,7 +236,7 @@ class MetadataService: redis_client.delete(lock_key) @staticmethod - def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]): + def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None): if dataset_id: lock_key = f"dataset_metadata_lock_{dataset_id}" if redis_client.get(lock_key): diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 33d7dacba0..69da3bfb79 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,7 +1,7 @@ import json import logging from json import JSONDecodeError -from typing import Optional, Union +from typing import Union from sqlalchemy import or_, select @@ -211,7 +211,7 @@ class ModelLoadBalancingService: def get_load_balancing_config( self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str - ) -> Optional[dict]: + ) -> dict | None: """ Get load balancing configuration. :param tenant_id: workspace id @@ -478,7 +478,7 @@ class ModelLoadBalancingService: model: str, model_type: str, credentials: dict, - config_id: Optional[str] = None, + config_id: str | None = None, ): """ Validate load balancing credentials. @@ -536,7 +536,7 @@ class ModelLoadBalancingService: model_type: ModelType, model: str, credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + load_balancing_model_config: LoadBalancingModelConfig | None = None, validate: bool = True, ): """ diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 510b1f1fe6..2901a0d273 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule @@ -52,7 +51,7 @@ class ModelProviderService: return provider_configuration - def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: + def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]: """ get provider list. @@ -128,9 +127,7 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credential( - self, tenant_id: str, provider: str, credential_id: Optional[str] = None - ) -> Optional[dict]: + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. @@ -216,7 +213,7 @@ class ModelProviderService: def get_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None - ) -> Optional[dict]: + ) -> dict | None: """ Retrieve model-specific credentials. @@ -449,7 +446,7 @@ class ModelProviderService: return model_schema.parameter_rules if model_schema else [] - def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: + def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None: """ get default model of model type. @@ -498,7 +495,7 @@ class ModelProviderService: def get_model_provider_icon( self, tenant_id: str, provider: str, icon_type: str, lang: str - ) -> tuple[Optional[bytes], Optional[str]]: + ) -> tuple[bytes | None, str | None]: """ get model provider icon. diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 2596e9f711..c214640653 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from core.ops.entities.config_entity import BaseTracingConfig from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map @@ -15,7 +15,7 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -153,7 +153,7 @@ class OpsService: project_url = None # check if trace config already exists - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 34f10ae407..fcfa52371d 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -5,7 +5,7 @@ import time from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Optional +from typing import Any from uuid import uuid4 import click @@ -281,7 +281,7 @@ class PluginMigration: return result @classmethod - def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]: + def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None: """ Fetch plugin unique identifier using plugin id. """ diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 9005f0669b..3b7ce20f83 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -1,7 +1,6 @@ import logging from collections.abc import Mapping, Sequence from mimetypes import guess_type -from typing import Optional from pydantic import BaseModel @@ -46,11 +45,11 @@ class PluginService: REDIS_TTL = 60 * 5 # 5 minutes @staticmethod - def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]: + def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: """ Fetch the latest plugin version """ - result: dict[str, Optional[PluginService.LatestPluginCache]] = {} + result: dict[str, PluginService.LatestPluginCache | None] = {} try: cache_not_exists = [] @@ -109,7 +108,7 @@ class PluginService: raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only") @staticmethod - def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]): + def _check_plugin_installation_scope(plugin_verification: PluginVerification | None): """ Check the plugin installation scope """ @@ -144,7 +143,7 @@ class PluginService: return manager.get_debugging_key(tenant_id) @staticmethod - def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]: + def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: """ List the latest versions of the plugins """ diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index df9e01e273..64751d186c 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -1,7 +1,6 @@ import json from os import path from pathlib import Path -from typing import Optional from flask import current_app @@ -14,7 +13,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): Retrieval recommended app from buildin, the location is constants/recommended_apps.json """ - builtin_data: Optional[dict] = None + builtin_data: dict | None = None def get_type(self) -> str: return RecommendAppType.BUILDIN @@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod - def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: + def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None: """ Fetch recommended app detail from builtin. :param app_id: App ID diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index a9733e0826..d0c49325dc 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import select from constants.languages import languages @@ -72,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None: """ Fetch recommended app detail from db. :param app_id: App ID diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 1e59287429..2d57769f63 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests @@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.REMOTE @classmethod - def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: + def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None: """ Fetch recommended app detail from dify official. :param app_id: App ID diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index d9c1b51fa1..544383a106 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,5 +1,3 @@ -from typing import Optional - from configs import dify_config from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -25,7 +23,7 @@ class RecommendedAppService: return result @classmethod - def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: + def get_recommend_app_detail(cls, app_id: str) -> dict | None: """ Get recommend app detail. :param app_id: app id diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 641e03c3cf..67a0106bbd 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -11,7 +11,7 @@ from services.message_service import MessageService class SavedMessageService: @classmethod def pagination_by_last_id( - cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int ) -> InfiniteScrollPagination: if not user: raise ValueError("User is required") @@ -32,7 +32,7 @@ class SavedMessageService: ) @classmethod - def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return saved_message = ( @@ -62,7 +62,7 @@ class SavedMessageService: db.session.commit() @classmethod - def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return saved_message = ( diff --git a/api/services/tag_service.py b/api/services/tag_service.py index dd67b19966..4674335ba8 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional from flask_login import current_user from sqlalchemy import func, select @@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod - def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None): + def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): query = ( db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index cb31111485..9db71dcd09 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -3,7 +3,7 @@ import logging import re from collections.abc import Mapping from pathlib import Path -from typing import Any, Optional +from typing import Any from sqlalchemy import exists, select from sqlalchemy.orm import Session @@ -604,7 +604,7 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider @@ -665,8 +665,8 @@ class BuiltinToolManageService: def save_custom_oauth_client_params( tenant_id: str, provider: str, - client_params: Optional[dict] = None, - enable_oauth_custom_client: Optional[bool] = None, + client_params: dict | None = None, + enable_oauth_custom_client: bool | None = None, ): """ setup oauth custom client diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index f245dd7527..51e9120b8d 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager @@ -10,7 +9,7 @@ logger = logging.getLogger(__name__) class ToolCommonService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None): + def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None): """ list tool providers diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index f5fc7f951f..93c632f92c 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from yarl import URL @@ -94,7 +94,7 @@ class ToolTransformService: def builtin_provider_to_user_provider( cls, provider_controller: BuiltinToolProviderController | PluginToolProviderController, - db_provider: Optional[BuiltinToolProvider], + db_provider: BuiltinToolProvider | None, decrypt_credentials: bool = True, ) -> ToolProviderApiEntity: """ diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 428abdde17..1c559f2c2b 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -19,7 +18,7 @@ logger = logging.getLogger(__name__) class VectorService: @classmethod def create_segments_vector( - cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str + cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] @@ -79,7 +78,7 @@ class VectorService: index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) @classmethod - def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): + def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): # update segment index task # format new index diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index c48e24f244..0f54e838f3 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -19,11 +19,11 @@ class WebConversationService: *, session: Session, app_model: App, - user: Optional[Union[Account, EndUser]], - last_id: Optional[str], + user: Union[Account, EndUser] | None, + last_id: str | None, limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None, + pinned: bool | None = None, sort_by="-updated_at", ) -> InfiniteScrollPagination: if not user: @@ -60,7 +60,7 @@ class WebConversationService: ) @classmethod - def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return pinned_conversation = ( @@ -92,7 +92,7 @@ class WebConversationService: db.session.commit() @classmethod - def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return pinned_conversation = ( diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index bb46bf3090..066dc9d741 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -1,7 +1,7 @@ import enum import secrets from datetime import UTC, datetime, timedelta -from typing import Any, Optional +from typing import Any from werkzeug.exceptions import NotFound, Unauthorized @@ -63,7 +63,7 @@ class WebAppAuthService: @classmethod def send_email_code_login_email( - cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US" + cls, account: Account | None = None, email: str | None = None, language: str = "en-US" ): email = account.email if account else email if email is None: @@ -82,7 +82,7 @@ class WebAppAuthService: return token @classmethod - def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @classmethod @@ -130,7 +130,7 @@ class WebAppAuthService: @classmethod def is_app_require_permission_check( - cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None + cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None ) -> bool: """ Check if the app requires permission check based on its access mode. diff --git a/api/services/website_service.py b/api/services/website_service.py index 131b96db13..2dc049fc72 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,7 +1,7 @@ import datetime import json from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import requests from flask_login import current_user @@ -21,9 +21,9 @@ class CrawlOptions: limit: int = 1 crawl_sub_pages: bool = False only_main_content: bool = False - includes: Optional[str] = None - excludes: Optional[str] = None - max_depth: Optional[int] = None + includes: str | None = None + excludes: str | None = None + max_depth: int | None = None use_sitemap: bool = True def get_include_paths(self) -> list[str]: diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 376f96c03a..9ce5b6dbe0 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import ( DatasetEntity, @@ -327,7 +327,7 @@ class WorkflowConverter: def _convert_to_knowledge_retrieval_node( self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity - ) -> Optional[dict]: + ) -> dict | None: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode @@ -383,7 +383,7 @@ class WorkflowConverter: graph: dict, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, - file_upload: Optional[FileUploadConfig] = None, + file_upload: FileUploadConfig | None = None, external_data_variable_node_mapping: dict[str, str] | None = None, ): """ @@ -403,7 +403,7 @@ class WorkflowConverter: ) role_prefix = None - prompts: Optional[Any] = None + prompts: Any | None = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index e43999a8c9..79d91cab4c 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,6 +1,5 @@ import threading from collections.abc import Sequence -from typing import Optional from sqlalchemy.orm import sessionmaker @@ -80,7 +79,7 @@ class WorkflowRunService: last_id=last_id, ) - def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: + def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None: """ Get workflow run detail diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e680de3502..abda024555 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ import json import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from uuid import uuid4 from sqlalchemy import exists, select @@ -88,7 +88,7 @@ class WorkflowService: ) return db.session.execute(stmt).scalar_one() - def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]: + def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: """ Get draft workflow """ @@ -108,7 +108,7 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: """ fetch published workflow by workflow_id """ @@ -130,7 +130,7 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + def get_published_workflow(self, app_model: App) -> Workflow | None: """ Get published workflow """ @@ -195,7 +195,7 @@ class WorkflowService: app_model: App, graph: dict, features: dict, - unique_hash: Optional[str], + unique_hash: str | None, account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], @@ -561,7 +561,7 @@ class WorkflowService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None: """ Get default config of node. :param node_type: node type @@ -857,7 +857,7 @@ class WorkflowService: def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict - ) -> Optional[Workflow]: + ) -> Workflow | None: """ Update workflow attributes diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 761ac6fc3d..62200715cc 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional import click from celery import shared_task @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]): +def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None): """ Clean document when document deleted. :param document_id: document id diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 986e9dbc3c..6b2907cffd 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional import click from celery import shared_task @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None): +def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None): """ Async create segment to index :param segment_id: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 0fb7076c85..bc64fda9c2 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -1,6 +1,5 @@ import unittest from datetime import UTC, datetime -from typing import Optional from unittest.mock import patch from uuid import uuid4 @@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.rollback() def _create_upload_file( - self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None ) -> UploadFile: """Helper method to create an UploadFile record for testing.""" if file_id is None: @@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase): return upload_file def _create_tool_file( - self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None ) -> ToolFile: """Helper method to create a ToolFile record for testing.""" if file_id is None: @@ -101,9 +100,7 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file( - self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None - ) -> File: + def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index d699866fb4..d59d5dc0fe 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -5,8 +5,6 @@ from decimal import Decimal from json import dumps # import monkeypatch -from typing import Optional - from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool @@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient): @staticmethod def generate_function_call( - tools: Optional[list[PromptMessageTool]], - ) -> Optional[AssistantPromptMessage.ToolCall]: + tools: list[PromptMessageTool] | None, + ) -> AssistantPromptMessage.ToolCall | None: if not tools or len(tools) == 0: return None function: PromptMessageTool = tools[0] @@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient): def mocked_chat_create_sync( model: str, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> LLMResult: tool_call = MockModelClass.generate_function_call(tools=tools) @@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient): def mocked_chat_create_stream( model: str, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> Generator[LLMResultChunk, None, None]: tool_call = MockModelClass.generate_function_call(tools=tools) @@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, ): return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index be5b4de5a2..f9f9f4f369 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,6 +1,5 @@ import os from collections import UserDict -from typing import Optional from unittest.mock import MagicMock import pytest @@ -22,7 +21,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: Optional[HTTPAdapter] = None, + adapter: HTTPAdapter | None = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index fd7ab0a22b..e0b908cece 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Union +from typing import Union import pytest from _pytest.monkeypatch import MonkeyPatch @@ -23,16 +23,16 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: Optional[HTTPAdapter] = None, + adapter: HTTPAdapter | None = None, pool_size: int = 2, - proxies: Optional[dict] = None, - password: Optional[str] = None, + proxies: dict | None = None, + password: str | None = None, **kwargs, ): self._conn = None self._read_consistency = read_consistency - def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase: + def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase: return RPCDatabase( name="dify", read_consistency=self._read_consistency, @@ -42,7 +42,7 @@ class MockTcvectordbClass: return True def describe_collection( - self, database_name: str, collection_name: str, timeout: Optional[float] = None + self, database_name: str, collection_name: str, timeout: float | None = None ) -> RPCCollection: index = Index( FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY), @@ -71,13 +71,13 @@ class MockTcvectordbClass: collection_name: str, shard: int, replicas: int, - description: Optional[str] = None, - index: Optional[Index] = None, - embedding: Optional[Embedding] = None, - timeout: Optional[float] = None, - ttl_config: Optional[dict] = None, - filter_index_config: Optional[FilterIndexConfig] = None, - indexes: Optional[list[IndexField]] = None, + description: str | None = None, + index: Index | None = None, + embedding: Embedding | None = None, + timeout: float | None = None, + ttl_config: dict | None = None, + filter_index_config: FilterIndexConfig | None = None, + indexes: list[IndexField] | None = None, ) -> RPCCollection: return RPCCollection( RPCDatabase( @@ -102,7 +102,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, documents: list[Union[Document, dict]], - timeout: Optional[float] = None, + timeout: float | None = None, build_index: bool = True, **kwargs, ): @@ -113,12 +113,12 @@ class MockTcvectordbClass: database_name: str, collection_name: str, vectors: list[list[float]], - filter: Optional[Filter] = None, + filter: Filter | None = None, params=None, retrieve_vector: bool = False, limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + output_fields: list[str] | None = None, + timeout: float | None = None, ) -> list[list[dict]]: return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]] @@ -126,14 +126,14 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - ann: Optional[Union[list[AnnSearch], AnnSearch]] = None, - match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None, - filter: Optional[Union[Filter, str]] = None, - rerank: Optional[Rerank] = None, - retrieve_vector: Optional[bool] = None, - output_fields: Optional[list[str]] = None, - limit: Optional[int] = None, - timeout: Optional[float] = None, + ann: Union[list[AnnSearch], AnnSearch] | None = None, + match: Union[list[KeywordSearch], KeywordSearch] | None = None, + filter: Union[Filter, str] | None = None, + rerank: Rerank | None = None, + retrieve_vector: bool | None = None, + output_fields: list[str] | None = None, + limit: int | None = None, + timeout: float | None = None, return_pd_object=False, **kwargs, ) -> list[list[dict]]: @@ -143,13 +143,13 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - document_ids: Optional[list] = None, + document_ids: list | None = None, retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + limit: int | None = None, + offset: int | None = None, + filter: Filter | None = None, + output_fields: list[str] | None = None, + timeout: float | None = None, ): return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] @@ -157,13 +157,13 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - document_ids: Optional[list[str]] = None, - filter: Optional[Filter] = None, - timeout: Optional[float] = None, + document_ids: list[str] | None = None, + filter: Filter | None = None, + timeout: float | None = None, ): return {"code": 0, "msg": "operation success"} - def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None): + def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None): return {"code": 0, "msg": "operation success"} diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py index 4b251ba836..70c85d4c98 100644 --- a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py @@ -1,6 +1,5 @@ import os from collections import UserDict -from typing import Optional import pytest from _pytest.monkeypatch import MonkeyPatch @@ -34,7 +33,7 @@ class MockIndex: include_vectors: bool = False, include_metadata: bool = False, filter: str = "", - data: Optional[str] = None, + data: str | None = None, namespace: str = "", include_data: bool = False, ): diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index ef373d968d..11129c4b0c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,7 +1,6 @@ import os import time import uuid -from typing import Optional from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom @@ -29,7 +28,7 @@ def get_mocked_fetch_memory(memory_text: str): human_prefix: str = "Human", ai_prefix: str = "Assistant", max_token_limit: int = 2000, - message_limit: Optional[int] = None, + message_limit: int | None = None, ): return memory_text diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index f28437f6c1..77ed8f261a 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -11,7 +11,6 @@ import logging import os from collections.abc import Generator from pathlib import Path -from typing import Optional import pytest from flask import Flask @@ -42,10 +41,10 @@ class DifyTestContainers: def __init__(self): """Initialize container management with default configurations.""" - self.postgres: Optional[PostgresContainer] = None - self.redis: Optional[RedisContainer] = None - self.dify_sandbox: Optional[DockerContainer] = None - self.dify_plugin_daemon: Optional[DockerContainer] = None + self.postgres: PostgresContainer | None = None + self.redis: RedisContainer | None = None + self.dify_sandbox: DockerContainer | None = None + self.dify_plugin_daemon: DockerContainer | None = None self._containers_started = False logger.info("DifyTestContainers initialized - ready to manage test containers") diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index b6fe8b73a2..21a792de06 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -1,6 +1,5 @@ import unittest from datetime import UTC, datetime -from typing import Optional from unittest.mock import patch from uuid import uuid4 @@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.rollback() def _create_upload_file( - self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None ) -> UploadFile: """Helper method to create an UploadFile record for testing.""" if file_id is None: @@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase): return upload_file def _create_tool_file( - self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None ) -> ToolFile: """Helper method to create a ToolFile record for testing.""" if file_id is None: @@ -102,9 +101,7 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file( - self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None - ) -> File: + def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index ea8a88692f..2765048734 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,7 +1,6 @@ import base64 import uuid from collections.abc import Sequence -from typing import Optional from unittest import mock import pytest @@ -47,7 +46,7 @@ class MockTokenBufferMemory: self.history_messages = history_messages or [] def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: int | None = None ) -> Sequence[PromptMessage]: if message_limit is not None: return self.history_messages[-message_limit * 2 :] diff --git a/api/tests/unit_tests/libs/test_flask_utils.py b/api/tests/unit_tests/libs/test_flask_utils.py index fb46ba50f3..e30433bfce 100644 --- a/api/tests/unit_tests/libs/test_flask_utils.py +++ b/api/tests/unit_tests/libs/test_flask_utils.py @@ -1,6 +1,5 @@ import contextvars import threading -from typing import Optional import pytest from flask import Flask @@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask: login_manager.init_app(app) @login_manager.user_loader - def load_user(user_id: str) -> Optional[User]: + def load_user(user_id: str) -> User | None: if user_id == "test_user": return User("test_user") return None diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index 1881ceac26..69766188f3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional # Mock redis_client before importing dataset_service from unittest.mock import Mock, call, patch @@ -37,7 +36,7 @@ class DocumentBatchUpdateTestDataFactory: enabled: bool = True, archived: bool = False, indexing_status: str = "completed", - completed_at: Optional[datetime.datetime] = None, + completed_at: datetime.datetime | None = None, **kwargs, ) -> Mock: """Create a mock document with specified attributes.""" diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index fb23863043..df5596f5c8 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, Optional +from typing import Any # Mock redis_client before importing dataset_service from unittest.mock import Mock, create_autospec, patch @@ -24,9 +24,9 @@ class DatasetUpdateTestDataFactory: description: str = "old_description", indexing_technique: str = "high_quality", retrieval_model: str = "old_model", - embedding_model_provider: Optional[str] = None, - embedding_model: Optional[str] = None, - collection_binding_id: Optional[str] = None, + embedding_model_provider: str | None = None, + embedding_model: str | None = None, + collection_binding_id: str | None = None, **kwargs, ) -> Mock: """Create a mock dataset with specified attributes."""