diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..6756a2fce6 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: + - package-ecosystem: "npm" + directory: "/web" + schedule: + interval: "weekly" + open-pull-requests-limit: 2 + - package-ecosystem: "uv" + directory: "/api" + schedule: + interval: "weekly" + open-pull-requests-limit: 2 diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 82ba95444f..068ba686fa 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -20,14 +20,60 @@ jobs: cd api uv sync --dev # Fix lint errors - uv run ruff check --fix-only . + uv run ruff check --fix . # Format code - uv run ruff format . + uv run ruff format . + - name: ast-grep run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all + # Convert Optional[T] to T | None (ignoring quoted types) + cat > /tmp/optional-rule.yml << 'EOF' + id: convert-optional-to-union + language: python + rule: + kind: generic_type + all: + - has: + kind: identifier + pattern: Optional + - has: + kind: type_parameter + has: + kind: type + pattern: $T + fix: $T | None + EOF + uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all + # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax) + find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; + find . -name "*.py.bak" -type f -delete + - name: mdformat run: | uvx mdformat .
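A note on the two-step `Optional` conversion above: the ast-grep rule and the follow-up sed pass go together because `T | None` is evaluated at runtime, so the rewrite is only safe for unquoted types. A minimal sketch of both facts (the `Node` class is illustrative, not from this diff):

```python
from typing import Optional

# The two spellings compare equal at runtime, so the mechanical rewrite
# does not change pydantic Field validation:
assert Optional[str] == str | None

class Node:
    # A quoted forward reference must stay wrapped in Optional[...]:
    parent: Optional["Node"] = None
    # The naive rewrite would yield `parent: "Node" | None`, which raises
    # TypeError when the class body is evaluated ("Node" is a plain str,
    # and str does not implement `|`) -- hence the sed repair step.
```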
+ + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup NodeJS + uses: actions/setup-node@v4 + with: + node-version: 22 + cache: pnpm + cache-dependency-path: ./web/package.json + + - name: Web dependencies + working-directory: ./web + run: pnpm install --frozen-lockfile + + - name: oxlint + working-directory: ./web + run: | + pnpx oxlint --fix + - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 98fa7c3b49..9cff3a3482 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -19,11 +19,23 @@ jobs: github.event.workflow_run.head_branch == 'deploy/enterprise' steps: - - name: Deploy to server - uses: appleboy/ssh-action@v0.1.8 - with: - host: ${{ secrets.ENTERPRISE_SSH_HOST }} - username: ${{ secrets.ENTERPRISE_SSH_USER }} - password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }} - script: | - ${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }} + - name: trigger deployments + env: + DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }} + DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }} + run: | + IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}" + BODY='{"project":"dify-api","tag":"deploy-enterprise"}' + + for ENDPOINT in "${ENDPOINTS[@]}"; do + ENDPOINT="$(echo "$ENDPOINT" | xargs)" + [ -z "$ENDPOINT" ] && continue + + API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}') + + curl -sSf -X POST \ + -H "Content-Type: application/json" \ + -H "X-Hub-Signature-256: $API_SIGNATURE" \ + -d "$BODY" \ + "$ENDPOINT" + done diff --git a/.gitignore b/.gitignore index bc354e639e..cbb7b4dac0 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,7 @@ web/public/fallback-*.js .roo/ api/.env.backup /clickzetta + +# Benchmark +scripts/stress-test/setup/config/ +scripts/stress-test/reports/ \ No newline at end of file diff --git a/Makefile b/Makefile index d82f6f24ad..ec7df3e03d 100644 --- a/Makefile +++ b/Makefile @@ -4,10 +4,13 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web API_IMAGE=$(DOCKER_REGISTRY)/dify-api VERSION=latest +# Default target - show help +.DEFAULT_GOAL := help + # Backend Development Environment Setup .PHONY: dev-setup prepare-docker prepare-web prepare-api -# Default dev setup target +# Dev setup target dev-setup: prepare-docker prepare-web prepare-api @echo "✅ Backend development environment setup complete!" @@ -46,6 +49,27 @@ dev-clean: @rm -rf api/storage @echo "✅ Cleanup complete" +# Backend Code Quality Commands +format: + @echo "🎨 Running ruff format..." + @uv run --project api --dev ruff format ./api + @echo "✅ Code formatting complete" + +check: + @echo "🔍 Running ruff check..." + @uv run --project api --dev ruff check ./api + @echo "✅ Code check complete" + +lint: + @echo "🔧 Running ruff format and check with fixes..." + @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api' + @echo "✅ Linting complete" + +type-check: + @echo "📝 Running type check with basedpyright..." + @uv run --directory api --dev basedpyright + @echo "✅ Type check complete" + # Build Docker images build-web: @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
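For context on the deploy-enterprise rewrite above: instead of SSH-ing into a server, the workflow now POSTs a JSON body to each endpoint and signs it with HMAC-SHA256 in the `X-Hub-Signature-256` header, GitHub-webhook style. The receiving service is not part of this diff; a hedged sketch of how such an endpoint might verify the signature:

```python
import hashlib
import hmac

def verify_signature(body: bytes, secret: str, header_value: str) -> bool:
    # header_value looks like "sha256=<hex digest>", matching the workflow's
    # `openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk ...` output.
    expected = "sha256=" + hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
    # Constant-time comparison avoids leaking the digest via timing.
    return hmac.compare_digest(expected, header_value)
```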
@@ -90,6 +114,12 @@ help: @echo " make prepare-api - Set up API environment" @echo " make dev-clean - Stop Docker middleware containers" @echo "" + @echo "Backend Code Quality:" + @echo " make format - Format code with ruff" + @echo " make check - Check code with ruff" + @echo " make lint - Format and fix code with ruff" + @echo " make type-check - Run type checking with basedpyright" + @echo "" @echo "Docker Build Targets:" @echo " make build-web - Build web Docker image" @echo " make build-api - Build API Docker image" @@ -98,4 +128,4 @@ help: @echo " make build-push-all - Build and push all Docker images" # Phony targets -.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help +.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check diff --git a/api/.env.example b/api/.env.example index 7081f4879d..967e5fa57e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -540,6 +540,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/api/.ruff.toml b/api/.ruff.toml index 9668dc9f76..67ad3b1449 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -5,7 +5,7 @@ line-length = 120 quote-style = "double" [lint] -preview = false +preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions @@ -45,6 +45,7 @@ select = [ "G001", # don't use str format to logging messages "G003", # don't use + in logging messages "G004", # don't use f-strings to format logging messages + "UP042", # use StrEnum ] ignore = [ @@ -64,6 +65,7 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg + "B901", # allow return in yield "B903", # class-as-data-structure "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict diff --git a/api/commands.py b/api/commands.py index c0dc913811..39c40fdf73 100644 --- a/api/commands.py +++ b/api/commands.py @@ -2,7 +2,7 @@ import base64 import json import logging import secrets -from typing import Any, Optional +from typing import Any import click import sqlalchemy as sa @@ -218,7 +218,9 @@ def migrate_annotation_vector_database(): if not dataset_collection_binding: click.echo(f"App annotation collection binding not found: {app.id}") continue - annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() + annotations = db.session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) + ).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -373,29 +375,25 @@ def migrate_knowledge_vector_database(): ) raise e - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset.id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, 
DocumentSegment.status == "completed", DocumentSegment.enabled == True, ) - .all() - ) + ).all() for segment in segments: document = Document( @@ -485,12 +483,12 @@ def convert_to_agent_apps(): click.echo(f"Converting app: {app.id}") try: - app.mode = AppMode.AGENT_CHAT.value + app.mode = AppMode.AGENT_CHAT db.session.commit() # update conversation mode to agent db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT.value} + {Conversation.mode: AppMode.AGENT_CHAT} ) db.session.commit() @@ -517,7 +515,7 @@ def add_qdrant_index(field: str): from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig for binding in bindings: if dify_config.QDRANT_URL is None: @@ -531,7 +529,21 @@ def add_qdrant_index(field: str): prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: - client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) + params = qdrant_config.to_qdrant_params() + # Check the type before using + if isinstance(params, PathQdrantParams): + # PathQdrantParams case + client = qdrant_client.QdrantClient(path=params.path) + else: + # UrlQdrantParams case - params is UrlQdrantParams + client = qdrant_client.QdrantClient( + url=params.url, + api_key=params.api_key, + timeout=int(params.timeout), + verify=params.verify, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + ) # create payload index client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 @@ -633,7 +645,7 @@ def old_metadata_migration(): @click.option("--email", prompt=True, help="Tenant account email.") @click.option("--name", prompt=True, help="Workspace name.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): +def create_tenant(email: str, language: str | None = None, name: str | None = None): """ Create tenant account """ @@ -947,7 +959,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) + conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) click.echo( click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") ) diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index f9c4d73463..9694f3db6b 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,28 +7,28 @@ class NotionConfig(BaseSettings): Configuration settings for Notion integration """ - NOTION_CLIENT_ID: Optional[str] = Field( + NOTION_CLIENT_ID: str | None = Field( description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_CLIENT_SECRET: Optional[str] = Field( + NOTION_CLIENT_SECRET: str | None = Field( description="Client secret for Notion API authentication. 
Required for OAuth 2.0 flow.", default=None, ) - NOTION_INTEGRATION_TYPE: Optional[str] = Field( + NOTION_INTEGRATION_TYPE: str | None = Field( description="Type of Notion integration." " Set to 'internal' for internal integrations, or None for public integrations.", default=None, ) - NOTION_INTERNAL_SECRET: Optional[str] = Field( + NOTION_INTERNAL_SECRET: str | None = Field( description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.", default=None, ) - NOTION_INTEGRATION_TOKEN: Optional[str] = Field( + NOTION_INTEGRATION_TOKEN: str | None = Field( description="Integration token for Notion API access. Used for direct API calls without OAuth flow.", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index f76a6bdb95..d72d01b49f 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeFloat from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class SentryConfig(BaseSettings): Configuration settings for Sentry error tracking and performance monitoring """ - SENTRY_DSN: Optional[str] = Field( + SENTRY_DSN: str | None = Field( description="Sentry Data Source Name (DSN)." " This is the unique identifier of your Sentry project, used to send events to the correct project.", default=None, ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ca63546f7c..6d3934a557 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import ( AliasChoices, @@ -31,6 +31,12 @@ class SecurityConfig(BaseSettings): description="Duration in minutes for which a password reset token remains valid", default=5, ) + + EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which an email register token remains valid", + default=5, + ) + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( description="Duration in minutes for which a change email token remains valid", default=5, @@ -51,7 +57,7 @@ class SecurityConfig(BaseSettings): default=False, ) - ADMIN_API_KEY: Optional[str] = Field( + ADMIN_API_KEY: str | None = Field( description="admin api key for authentication", default=None, ) @@ -91,17 +97,17 @@ class CodeExecutionSandboxConfig(BaseSettings): default="dify-sandbox", ) - CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_CONNECT_TIMEOUT: float | None = Field( description="Connection timeout in seconds for code execution requests", default=10.0, ) - CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_READ_TIMEOUT: float | None = Field( description="Read timeout in seconds for code execution requests", default=60.0, ) - CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_WRITE_TIMEOUT: float | None = Field( description="Write timeout in seconds for code execution request", default=10.0, ) @@ -362,17 +368,17 @@ class HttpConfig(BaseSettings): default=3, ) - SSRF_PROXY_ALL_URL: Optional[str] = Field( + SSRF_PROXY_ALL_URL: str | None = Field( description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTP_URL: Optional[str] = Field( + SSRF_PROXY_HTTP_URL: str | None = Field( description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", default=None, )
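Stepping back from the config conversions for a moment: the commands.py hunks earlier in this diff (together with the ast-grep `.filter()` → `.where()` rewrite in autofix.yml) standardize on SQLAlchemy 2.0-style queries, i.e. `select()` driven through `Session.scalars()` instead of the legacy `Query` API. A self-contained sketch of the pattern, using a hypothetical `Item` model rather than a model from this codebase:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    enabled: Mapped[bool] = mapped_column(default=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy 1.x style:  session.query(Item).filter(Item.enabled == True).all()
    # 2.0 style, as rewritten throughout this diff:
    items = session.scalars(select(Item).where(Item.enabled == True)).all()
```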
- SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + SSRF_PROXY_HTTPS_URL: str | None = Field( description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) @@ -414,7 +420,7 @@ class InnerAPIConfig(BaseSettings): default=False, ) - INNER_API_KEY: Optional[str] = Field( + INNER_API_KEY: str | None = Field( description="API key for accessing the internal API", default=None, ) @@ -430,7 +436,7 @@ class LoggingConfig(BaseSettings): default="INFO", ) - LOG_FILE: Optional[str] = Field( + LOG_FILE: str | None = Field( description="File path for log output.", default=None, ) @@ -450,12 +456,12 @@ class LoggingConfig(BaseSettings): default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) - LOG_DATEFORMAT: Optional[str] = Field( + LOG_DATEFORMAT: str | None = Field( description="Date format string for log timestamps", default=None, ) - LOG_TZ: Optional[str] = Field( + LOG_TZ: str | None = Field( description="Timezone for log timestamps (e.g., 'America/New_York')", default="UTC", ) @@ -627,22 +633,22 @@ class AuthConfig(BaseSettings): default="/console/api/oauth/authorize", ) - GITHUB_CLIENT_ID: Optional[str] = Field( + GITHUB_CLIENT_ID: str | None = Field( description="GitHub OAuth client ID", default=None, ) - GITHUB_CLIENT_SECRET: Optional[str] = Field( + GITHUB_CLIENT_SECRET: str | None = Field( description="GitHub OAuth client secret", default=None, ) - GOOGLE_CLIENT_ID: Optional[str] = Field( + GOOGLE_CLIENT_ID: str | None = Field( description="Google OAuth client ID", default=None, ) - GOOGLE_CLIENT_SECRET: Optional[str] = Field( + GOOGLE_CLIENT_SECRET: str | None = Field( description="Google OAuth client secret", default=None, ) @@ -677,6 +683,11 @@ class AuthConfig(BaseSettings): default=86400, ) + EMAIL_REGISTER_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying email registration after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ @@ -705,42 +716,42 @@ class MailConfig(BaseSettings): Configuration for email services """ - MAIL_TYPE: Optional[str] = Field( + MAIL_TYPE: str | None = Field( description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.", default=None, ) - MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( + MAIL_DEFAULT_SEND_FROM: str | None = Field( description="Default email address to use as the sender", default=None, ) - RESEND_API_KEY: Optional[str] = Field( + RESEND_API_KEY: str | None = Field( description="API key for Resend email service", default=None, ) - RESEND_API_URL: Optional[str] = Field( + RESEND_API_URL: str | None = Field( description="API URL for Resend email service", default=None, ) - SMTP_SERVER: Optional[str] = Field( + SMTP_SERVER: str | None = Field( description="SMTP server hostname", default=None, ) - SMTP_PORT: Optional[int] = Field( + SMTP_PORT: int | None = Field( description="SMTP server port number", default=465, ) - SMTP_USERNAME: Optional[str] = Field( + SMTP_USERNAME: str | None = Field( description="Username for SMTP authentication", default=None, ) - SMTP_PASSWORD: Optional[str] = Field( + SMTP_PASSWORD: str | None = Field( description="Password for SMTP authentication", default=None, ) @@ -760,7 +771,7 @@ class MailConfig(BaseSettings): default=50, ) - SENDGRID_API_KEY: Optional[str] = Field( + SENDGRID_API_KEY: str | None = Field( description="API key for SendGrid service", default=None, ) @@ -783,17 +794,17 @@ class 
RagEtlConfig(BaseSettings): default="database", ) - UNSTRUCTURED_API_URL: Optional[str] = Field( + UNSTRUCTURED_API_URL: str | None = Field( description="API URL for Unstructured.io service", default=None, ) - UNSTRUCTURED_API_KEY: Optional[str] = Field( + UNSTRUCTURED_API_KEY: str | None = Field( description="API key for Unstructured.io service", default="", ) - SCARF_NO_ANALYTICS: Optional[str] = Field( + SCARF_NO_ANALYTICS: str | None = Field( description="This is about whether to disable Scarf analytics in Unstructured library.", default="false", ) diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 7633ffcf8a..4ad30014c7 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt from pydantic_settings import BaseSettings @@ -40,17 +38,17 @@ class HostedOpenAiConfig(BaseSettings): Configuration for hosted OpenAI service """ - HOSTED_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_OPENAI_API_KEY: str | None = Field( description="API key for hosted OpenAI service", default=None, ) - HOSTED_OPENAI_API_BASE: Optional[str] = Field( + HOSTED_OPENAI_API_BASE: str | None = Field( description="Base URL for hosted OpenAI API", default=None, ) - HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( + HOSTED_OPENAI_API_ORGANIZATION: str | None = Field( description="Organization ID for hosted OpenAI service", default=None, ) @@ -110,12 +108,12 @@ class HostedAzureOpenAiConfig(BaseSettings): default=False, ) - HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_KEY: str | None = Field( description="API key for hosted Azure OpenAI service", default=None, ) - HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_BASE: str | None = Field( description="Base URL for hosted Azure OpenAI API", default=None, ) @@ -131,12 +129,12 @@ class HostedAnthropicConfig(BaseSettings): Configuration for hosted Anthropic service """ - HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( + HOSTED_ANTHROPIC_API_BASE: str | None = Field( description="Base URL for hosted Anthropic API", default=None, ) - HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( + HOSTED_ANTHROPIC_API_KEY: str | None = Field( description="API key for hosted Anthropic service", default=None, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 591c24cbe0..dbad90270e 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Any, Literal, Optional +from typing import Any, Literal from urllib.parse import parse_qsl, quote_plus from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field @@ -78,18 +78,18 @@ class StorageConfig(BaseSettings): class VectorStoreConfig(BaseSettings): - VECTOR_STORE: Optional[str] = Field( + VECTOR_STORE: str | None = Field( description="Type of vector store to use for efficient similarity search." 
" Set to None if not using a vector store.", default=None, ) - VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( + VECTOR_STORE_WHITELIST_ENABLE: bool | None = Field( description="Enable whitelist for vector store.", default=False, ) - VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field( + VECTOR_INDEX_NAME_PREFIX: str | None = Field( description="Prefix used to create collection name in vector database", default="Vector_index", ) @@ -225,26 +225,26 @@ class CeleryConfig(DatabaseConfig): default="redis", ) - CELERY_BROKER_URL: Optional[str] = Field( + CELERY_BROKER_URL: str | None = Field( description="URL of the message broker for Celery tasks.", default=None, ) - CELERY_USE_SENTINEL: Optional[bool] = Field( + CELERY_USE_SENTINEL: bool | None = Field( description="Whether to use Redis Sentinel for high availability.", default=False, ) - CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field( + CELERY_SENTINEL_MASTER_NAME: str | None = Field( description="Name of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_PASSWORD: Optional[str] = Field( + CELERY_SENTINEL_PASSWORD: str | None = Field( description="Password of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field( description="Timeout for Redis Sentinel socket operations in seconds.", default=0.1, ) @@ -268,12 +268,12 @@ class InternalTestConfig(BaseSettings): Configuration settings for Internal Test """ - AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + AWS_SECRET_ACCESS_KEY: str | None = Field( description="Internal test AWS secret access key", default=None, ) - AWS_ACCESS_KEY_ID: Optional[str] = Field( + AWS_ACCESS_KEY_ID: str | None = Field( description="Internal test AWS access key ID", default=None, ) @@ -284,15 +284,15 @@ class DatasetQueueMonitorConfig(BaseSettings): Configuration settings for Dataset Queue Monitor """ - QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field( + QUEUE_MONITOR_THRESHOLD: NonNegativeInt | None = Field( description="Threshold for dataset queue monitor", default=200, ) - QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field( + QUEUE_MONITOR_ALERT_EMAILS: str | None = Field( description="Emails for dataset queue monitor alert, separated by commas", default=None, ) - QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field( + QUEUE_MONITOR_INTERVAL: NonNegativeFloat | None = Field( description="Interval for dataset queue monitor in minutes", default=30, ) diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 16dca98cfa..4705b28c69 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt from pydantic_settings import BaseSettings @@ -19,12 +17,12 @@ class RedisConfig(BaseSettings): default=6379, ) - REDIS_USERNAME: Optional[str] = Field( + REDIS_USERNAME: str | None = Field( description="Username for Redis authentication (if required)", default=None, ) - REDIS_PASSWORD: Optional[str] = Field( + REDIS_PASSWORD: str | None = Field( description="Password for Redis authentication (if required)", default=None, ) @@ -44,47 +42,47 @@ class RedisConfig(BaseSettings): default="CERT_NONE", ) - REDIS_SSL_CA_CERTS: Optional[str] = Field( + REDIS_SSL_CA_CERTS: str | None = Field( description="Path to the CA certificate file for SSL verification", default=None, ) 
- REDIS_SSL_CERTFILE: Optional[str] = Field( + REDIS_SSL_CERTFILE: str | None = Field( description="Path to the client certificate file for SSL authentication", default=None, ) - REDIS_SSL_KEYFILE: Optional[str] = Field( + REDIS_SSL_KEYFILE: str | None = Field( description="Path to the client private key file for SSL authentication", default=None, ) - REDIS_USE_SENTINEL: Optional[bool] = Field( + REDIS_USE_SENTINEL: bool | None = Field( description="Enable Redis Sentinel mode for high availability", default=False, ) - REDIS_SENTINELS: Optional[str] = Field( + REDIS_SENTINELS: str | None = Field( description="Comma-separated list of Redis Sentinel nodes (host:port)", default=None, ) - REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field( + REDIS_SENTINEL_SERVICE_NAME: str | None = Field( description="Name of the Redis Sentinel service to monitor", default=None, ) - REDIS_SENTINEL_USERNAME: Optional[str] = Field( + REDIS_SENTINEL_USERNAME: str | None = Field( description="Username for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_PASSWORD: Optional[str] = Field( + REDIS_SENTINEL_PASSWORD: str | None = Field( description="Password for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + REDIS_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field( description="Socket timeout in seconds for Redis Sentinel connections", default=0.1, ) @@ -94,12 +92,12 @@ class RedisConfig(BaseSettings): default=False, ) - REDIS_CLUSTERS: Optional[str] = Field( + REDIS_CLUSTERS: str | None = Field( description="Comma-separated list of Redis Clusters nodes (host:port)", default=None, ) - REDIS_CLUSTERS_PASSWORD: Optional[str] = Field( + REDIS_CLUSTERS_PASSWORD: str | None = Field( description="Password for Redis Clusters authentication (if required)", default=None, ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 07eb527170..331c486d54 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,37 +7,37 @@ class AliyunOSSStorageConfig(BaseSettings): Configuration settings for Aliyun Object Storage Service (OSS) """ - ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( + ALIYUN_OSS_BUCKET_NAME: str | None = Field( description="Name of the Aliyun OSS bucket to store and retrieve objects", default=None, ) - ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( + ALIYUN_OSS_ACCESS_KEY: str | None = Field( description="Access key ID for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( + ALIYUN_OSS_SECRET_KEY: str | None = Field( description="Secret access key for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_ENDPOINT: Optional[str] = Field( + ALIYUN_OSS_ENDPOINT: str | None = Field( description="URL of the Aliyun OSS endpoint for your chosen region", default=None, ) - ALIYUN_OSS_REGION: Optional[str] = Field( + ALIYUN_OSS_REGION: str | None = Field( description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')", default=None, ) - ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( + ALIYUN_OSS_AUTH_VERSION: str | None = Field( description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')", default=None, ) - ALIYUN_OSS_PATH: 
Optional[str] = Field( + ALIYUN_OSS_PATH: str | None = Field( description="Base path within the bucket to store objects (e.g., 'my-app-data/')", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index e14c210718..9277a335f7 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +9,27 @@ class S3StorageConfig(BaseSettings): Configuration settings for S3-compatible object storage """ - S3_ENDPOINT: Optional[str] = Field( + S3_ENDPOINT: str | None = Field( description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')", default=None, ) - S3_REGION: Optional[str] = Field( + S3_REGION: str | None = Field( description="Region where the S3 bucket is located (e.g., 'us-east-1')", default=None, ) - S3_BUCKET_NAME: Optional[str] = Field( + S3_BUCKET_NAME: str | None = Field( description="Name of the S3 bucket to store and retrieve objects", default=None, ) - S3_ACCESS_KEY: Optional[str] = Field( + S3_ACCESS_KEY: str | None = Field( description="Access key ID for authenticating with the S3 service", default=None, ) - S3_SECRET_KEY: Optional[str] = Field( + S3_SECRET_KEY: str | None = Field( description="Secret access key for authenticating with the S3 service", default=None, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index b7ab5247a9..7195d446b1 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class AzureBlobStorageConfig(BaseSettings): Configuration settings for Azure Blob Storage """ - AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_NAME: str | None = Field( description="Name of the Azure Storage account (e.g., 'mystorageaccount')", default=None, ) - AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_KEY: str | None = Field( description="Access key for authenticating with the Azure Storage account", default=None, ) - AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( + AZURE_BLOB_CONTAINER_NAME: str | None = Field( description="Name of the Azure Blob container to store and retrieve objects", default=None, ) - AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_URL: str | None = Field( description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')", default=None, ) diff --git a/api/configs/middleware/storage/baidu_obs_storage_config.py b/api/configs/middleware/storage/baidu_obs_storage_config.py index e7913b0acc..138a0db650 100644 --- a/api/configs/middleware/storage/baidu_obs_storage_config.py +++ b/api/configs/middleware/storage/baidu_obs_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class BaiduOBSStorageConfig(BaseSettings): Configuration settings for Baidu Object Storage Service (OBS) """ - BAIDU_OBS_BUCKET_NAME: Optional[str] = Field( + BAIDU_OBS_BUCKET_NAME: str | None = Field( description="Name of the Baidu OBS 
bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - BAIDU_OBS_ACCESS_KEY: Optional[str] = Field( + BAIDU_OBS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_SECRET_KEY: Optional[str] = Field( + BAIDU_OBS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_ENDPOINT: Optional[str] = Field( + BAIDU_OBS_ENDPOINT: str | None = Field( description="URL of the Baidu OSS endpoint for your chosen region (e.g., 'https://.bj.bcebos.com')", default=None, ) diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py index 56e1b6a957..035650d98a 100644 --- a/api/configs/middleware/storage/clickzetta_volume_storage_config.py +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -1,7 +1,5 @@ """ClickZetta Volume Storage Configuration""" -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ from pydantic_settings import BaseSettings class ClickZettaVolumeStorageConfig(BaseSettings): """Configuration for ClickZetta Volume storage.""" - CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( + CLICKZETTA_VOLUME_USERNAME: str | None = Field( description="Username for ClickZetta Volume authentication", default=None, ) - CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( + CLICKZETTA_VOLUME_PASSWORD: str | None = Field( description="Password for ClickZetta Volume authentication", default=None, ) - CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( + CLICKZETTA_VOLUME_INSTANCE: str | None = Field( description="ClickZetta instance identifier", default=None, ) @@ -49,7 +47,7 @@ class ClickZettaVolumeStorageConfig(BaseSettings): default="user", ) - CLICKZETTA_VOLUME_NAME: Optional[str] = Field( + CLICKZETTA_VOLUME_NAME: str | None = Field( description="ClickZetta volume name for external volumes", default=None, ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e5d763d7f5..a63eb798a8 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class GoogleCloudStorageConfig(BaseSettings): Configuration settings for Google Cloud Storage """ - GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( + GOOGLE_STORAGE_BUCKET_NAME: str | None = Field( description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')", default=None, ) - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( + GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: str | None = Field( description="Base64-encoded JSON key file for Google Cloud service account authentication", default=None, ) diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py index be983b5187..5b5cd2f750 100644 --- a/api/configs/middleware/storage/huawei_obs_storage_config.py +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class 
HuaweiCloudOBSStorageConfig(BaseSettings): Configuration settings for Huawei Cloud Object Storage Service (OBS) """ - HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field( + HUAWEI_OBS_BUCKET_NAME: str | None = Field( description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field( + HUAWEI_OBS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SECRET_KEY: Optional[str] = Field( + HUAWEI_OBS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SERVER: Optional[str] = Field( + HUAWEI_OBS_SERVER: str | None = Field( description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", default=None, ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index edc245bcac..70815a0055 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class OCIStorageConfig(BaseSettings): Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage """ - OCI_ENDPOINT: Optional[str] = Field( + OCI_ENDPOINT: str | None = Field( description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')", default=None, ) - OCI_REGION: Optional[str] = Field( + OCI_REGION: str | None = Field( description="OCI region where the bucket is located (e.g., 'us-phoenix-1')", default=None, ) - OCI_BUCKET_NAME: Optional[str] = Field( + OCI_BUCKET_NAME: str | None = Field( description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')", default=None, ) - OCI_ACCESS_KEY: Optional[str] = Field( + OCI_ACCESS_KEY: str | None = Field( description="Access key (also known as API key) for authenticating with OCI Object Storage", default=None, ) - OCI_SECRET_KEY: Optional[str] = Field( + OCI_SECRET_KEY: str | None = Field( description="Secret key associated with the access key for authenticating with OCI Object Storage", default=None, ) diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py index dcf7c20cf9..7f140fc5b9 100644 --- a/api/configs/middleware/storage/supabase_storage_config.py +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ class SupabaseStorageConfig(BaseSettings): Configuration settings for Supabase Object Storage Service """ - SUPABASE_BUCKET_NAME: Optional[str] = Field( + SUPABASE_BUCKET_NAME: str | None = Field( description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')", default=None, ) - SUPABASE_API_KEY: Optional[str] = Field( + SUPABASE_API_KEY: str | None = Field( description="API KEY for authenticating with Supabase", default=None, ) - SUPABASE_URL: Optional[str] = Field( + SUPABASE_URL: str | None = Field( description="URL of the Supabase", default=None, ) diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 
255c4e8938..e297e748e9 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class TencentCloudCOSStorageConfig(BaseSettings): Configuration settings for Tencent Cloud Object Storage (COS) """ - TENCENT_COS_BUCKET_NAME: Optional[str] = Field( + TENCENT_COS_BUCKET_NAME: str | None = Field( description="Name of the Tencent Cloud COS bucket to store and retrieve objects", default=None, ) - TENCENT_COS_REGION: Optional[str] = Field( + TENCENT_COS_REGION: str | None = Field( description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')", default=None, ) - TENCENT_COS_SECRET_ID: Optional[str] = Field( + TENCENT_COS_SECRET_ID: str | None = Field( description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SECRET_KEY: Optional[str] = Field( + TENCENT_COS_SECRET_KEY: str | None = Field( description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SCHEME: Optional[str] = Field( + TENCENT_COS_SCHEME: str | None = Field( description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, ) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index 06c3ae4d3e..be01f2dc36 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class VolcengineTOSStorageConfig(BaseSettings): Configuration settings for Volcengine Tinder Object Storage (TOS) """ - VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field( + VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')", default=None, ) - VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field( + VOLCENGINE_TOS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field( + VOLCENGINE_TOS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field( + VOLCENGINE_TOS_ENDPOINT: str | None = Field( description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')", default=None, ) - VOLCENGINE_TOS_REGION: Optional[str] = Field( + VOLCENGINE_TOS_REGION: str | None = Field( description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')", default=None, ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index cb8dc7d724..539b9c0963 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -11,37 +9,37 @@ class AnalyticdbConfig(BaseSettings): https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled 
""" - ANALYTICDB_KEY_ID: Optional[str] = Field( + ANALYTICDB_KEY_ID: str | None = Field( default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication." ) - ANALYTICDB_KEY_SECRET: Optional[str] = Field( + ANALYTICDB_KEY_SECRET: str | None = Field( default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access." ) - ANALYTICDB_REGION_ID: Optional[str] = Field( + ANALYTICDB_REGION_ID: str | None = Field( default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').", ) - ANALYTICDB_INSTANCE_ID: Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: str | None = Field( default=None, description="The unique identifier of the AnalyticDB instance you want to connect to.", ) - ANALYTICDB_ACCOUNT: Optional[str] = Field( + ANALYTICDB_ACCOUNT: str | None = Field( default=None, description="The account name used to log in to the AnalyticDB instance" " (usually the initial account created with the instance).", ) - ANALYTICDB_PASSWORD: Optional[str] = Field( + ANALYTICDB_PASSWORD: str | None = Field( default=None, description="The password associated with the AnalyticDB account for database authentication." ) - ANALYTICDB_NAMESPACE: Optional[str] = Field( + ANALYTICDB_NAMESPACE: str | None = Field( default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)." ) - ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( + ANALYTICDB_NAMESPACE_PASSWORD: str | None = Field( default=None, description="The password for accessing the specified namespace within the AnalyticDB instance" " (if namespace feature is enabled).", ) - ANALYTICDB_HOST: Optional[str] = Field( + ANALYTICDB_HOST: str | None = Field( default=None, description="The host of the AnalyticDB instance you want to connect to." 
) ANALYTICDB_PORT: PositiveInt = Field( diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 44742c2e2f..4b6ddb3bde 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class BaiduVectorDBConfig(BaseSettings): Configuration settings for Baidu Vector Database """ - BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field( + BAIDU_VECTOR_DB_ENDPOINT: str | None = Field( description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", default=None, ) @@ -19,17 +17,17 @@ class BaiduVectorDBConfig(BaseSettings): default=30000, ) - BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( + BAIDU_VECTOR_DB_ACCOUNT: str | None = Field( description="Account for authenticating with the Baidu Vector Database", default=None, ) - BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( + BAIDU_VECTOR_DB_API_KEY: str | None = Field( description="API key for authenticating with the Baidu Vector Database service", default=None, ) - BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( + BAIDU_VECTOR_DB_DATABASE: str | None = Field( description="Name of the specific Baidu Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index e83a9902de..3a78980b91 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class ChromaConfig(BaseSettings): Configuration settings for Chroma vector database """ - CHROMA_HOST: Optional[str] = Field( + CHROMA_HOST: str | None = Field( description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')", default=None, ) @@ -19,22 +17,22 @@ class ChromaConfig(BaseSettings): default=8000, ) - CHROMA_TENANT: Optional[str] = Field( + CHROMA_TENANT: str | None = Field( description="Tenant identifier for multi-tenancy support in Chroma", default=None, ) - CHROMA_DATABASE: Optional[str] = Field( + CHROMA_DATABASE: str | None = Field( description="Name of the Chroma database to connect to", default=None, ) - CHROMA_AUTH_PROVIDER: Optional[str] = Field( + CHROMA_AUTH_PROVIDER: str | None = Field( description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)", default=None, ) - CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( + CHROMA_AUTH_CREDENTIALS: str | None = Field( description="Authentication credentials for Chroma (format depends on the auth provider)", default=None, ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py index 61bc01202b..e8172b5299 100644 --- a/api/configs/middleware/vdb/clickzetta_config.py +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,62 +7,62 @@ class ClickzettaConfig(BaseSettings): Clickzetta Lakehouse vector database configuration """ - CLICKZETTA_USERNAME: Optional[str] = Field( + CLICKZETTA_USERNAME: str | None = Field( description="Username for authenticating with Clickzetta Lakehouse", default=None, ) - CLICKZETTA_PASSWORD: 
Optional[str] = Field( + CLICKZETTA_PASSWORD: str | None = Field( description="Password for authenticating with Clickzetta Lakehouse", default=None, ) - CLICKZETTA_INSTANCE: Optional[str] = Field( + CLICKZETTA_INSTANCE: str | None = Field( description="Clickzetta Lakehouse instance ID", default=None, ) - CLICKZETTA_SERVICE: Optional[str] = Field( + CLICKZETTA_SERVICE: str | None = Field( description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')", default="api.clickzetta.com", ) - CLICKZETTA_WORKSPACE: Optional[str] = Field( + CLICKZETTA_WORKSPACE: str | None = Field( description="Clickzetta workspace name", default="default", ) - CLICKZETTA_VCLUSTER: Optional[str] = Field( + CLICKZETTA_VCLUSTER: str | None = Field( description="Clickzetta virtual cluster name", default="default_ap", ) - CLICKZETTA_SCHEMA: Optional[str] = Field( + CLICKZETTA_SCHEMA: str | None = Field( description="Database schema name in Clickzetta", default="public", ) - CLICKZETTA_BATCH_SIZE: Optional[int] = Field( + CLICKZETTA_BATCH_SIZE: int | None = Field( description="Batch size for bulk insert operations", default=100, ) - CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field( + CLICKZETTA_ENABLE_INVERTED_INDEX: bool | None = Field( description="Enable inverted index for full-text search capabilities", default=True, ) - CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field( + CLICKZETTA_ANALYZER_TYPE: str | None = Field( description="Analyzer type for full-text search: keyword, english, chinese, unicode", default="chinese", ) - CLICKZETTA_ANALYZER_MODE: Optional[str] = Field( + CLICKZETTA_ANALYZER_MODE: str | None = Field( description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)", default="smart", ) - CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: str | None = Field( description="Distance function for vector similarity: l2_distance or cosine_distance", default="cosine_distance", ) diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py index b81cbf8959..a365e30263 100644 --- a/api/configs/middleware/vdb/couchbase_config.py +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class CouchbaseConfig(BaseSettings): Couchbase configs """ - COUCHBASE_CONNECTION_STRING: Optional[str] = Field( + COUCHBASE_CONNECTION_STRING: str | None = Field( description="COUCHBASE connection string", default=None, ) - COUCHBASE_USER: Optional[str] = Field( + COUCHBASE_USER: str | None = Field( description="COUCHBASE user", default=None, ) - COUCHBASE_PASSWORD: Optional[str] = Field( + COUCHBASE_PASSWORD: str | None = Field( description="COUCHBASE password", default=None, ) - COUCHBASE_BUCKET_NAME: Optional[str] = Field( + COUCHBASE_BUCKET_NAME: str | None = Field( description="COUCHBASE bucket name", default=None, ) - COUCHBASE_SCOPE_NAME: Optional[str] = Field( + COUCHBASE_SCOPE_NAME: str | None = Field( description="COUCHBASE scope name", default=None, ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py index 8c4b333d45..a0efd41417 100644 --- a/api/configs/middleware/vdb/elasticsearch_config.py +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt, model_validator from pydantic_settings import 
BaseSettings @@ -10,7 +8,7 @@ class ElasticsearchConfig(BaseSettings): Can load from environment variables or .env files. """ - ELASTICSEARCH_HOST: Optional[str] = Field( + ELASTICSEARCH_HOST: str | None = Field( description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')", default="127.0.0.1", ) @@ -20,30 +18,28 @@ class ElasticsearchConfig(BaseSettings): default=9200, ) - ELASTICSEARCH_USERNAME: Optional[str] = Field( + ELASTICSEARCH_USERNAME: str | None = Field( description="Username for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) - ELASTICSEARCH_PASSWORD: Optional[str] = Field( + ELASTICSEARCH_PASSWORD: str | None = Field( description="Password for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) # Elastic Cloud (optional) - ELASTICSEARCH_USE_CLOUD: Optional[bool] = Field( + ELASTICSEARCH_USE_CLOUD: bool | None = Field( description="Set to True to use Elastic Cloud instead of self-hosted Elasticsearch", default=False ) - ELASTICSEARCH_CLOUD_URL: Optional[str] = Field( + ELASTICSEARCH_CLOUD_URL: str | None = Field( description="Full URL for Elastic Cloud deployment (e.g., 'https://example.es.region.aws.found.io:443')", default=None, ) - ELASTICSEARCH_API_KEY: Optional[str] = Field( - description="API key for authenticating with Elastic Cloud", default=None - ) + ELASTICSEARCH_API_KEY: str | None = Field(description="API key for authenticating with Elastic Cloud", default=None) # Common options - ELASTICSEARCH_CA_CERTS: Optional[str] = Field( + ELASTICSEARCH_CA_CERTS: str | None = Field( description="Path to CA certificate file for SSL verification", default=None ) ELASTICSEARCH_VERIFY_CERTS: bool = Field( diff --git a/api/configs/middleware/vdb/huawei_cloud_config.py b/api/configs/middleware/vdb/huawei_cloud_config.py index 2290c60499..d64cb870fa 100644 --- a/api/configs/middleware/vdb/huawei_cloud_config.py +++ b/api/configs/middleware/vdb/huawei_cloud_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ class HuaweiCloudConfig(BaseSettings): Configuration settings for Huawei cloud search service """ - HUAWEI_CLOUD_HOSTS: Optional[str] = Field( + HUAWEI_CLOUD_HOSTS: str | None = Field( description="Hostname or IP address of the Huawei cloud search service instance", default=None, ) - HUAWEI_CLOUD_USER: Optional[str] = Field( + HUAWEI_CLOUD_USER: str | None = Field( description="Username for authenticating with Huawei cloud search service", default=None, ) - HUAWEI_CLOUD_PASSWORD: Optional[str] = Field( + HUAWEI_CLOUD_PASSWORD: str | None = Field( description="Password for authenticating with Huawei cloud search service", default=None, ) diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py index e80e3f4a35..262d5a1f26 100644 --- a/api/configs/middleware/vdb/lindorm_config.py +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class LindormConfig(BaseSettings): Lindorm configs """ - LINDORM_URL: Optional[str] = Field( + LINDORM_URL: str | None = Field( description="Lindorm url", default=None, ) - LINDORM_USERNAME: Optional[str] = Field( + LINDORM_USERNAME: str | None = Field( description="Lindorm user", default=None, ) - LINDORM_PASSWORD: Optional[str] = Field( + LINDORM_PASSWORD: str | None = Field( 
description="Lindorm password", default=None, ) - DEFAULT_INDEX_TYPE: Optional[str] = Field( + DEFAULT_INDEX_TYPE: str | None = Field( description="Lindorm Vector Index Type, hnsw or flat is available in dify", default="hnsw", ) - DEFAULT_DISTANCE_TYPE: Optional[str] = Field( + DEFAULT_DISTANCE_TYPE: str | None = Field( description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2" ) - USING_UGC_INDEX: Optional[bool] = Field( + USING_UGC_INDEX: bool | None = Field( description="Using UGC index will store the same type of Index in a single index but can retrieve separately.", default=False, ) - LINDORM_QUERY_TIMEOUT: Optional[float] = Field(description="The lindorm search request timeout (s)", default=2.0) + LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index d398ef5bd8..05cee51cc9 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class MilvusConfig(BaseSettings): Configuration settings for Milvus vector database """ - MILVUS_URI: Optional[str] = Field( + MILVUS_URI: str | None = Field( description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')", default="http://127.0.0.1:19530", ) - MILVUS_TOKEN: Optional[str] = Field( + MILVUS_TOKEN: str | None = Field( description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: Optional[str] = Field( + MILVUS_USER: str | None = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, ) - MILVUS_PASSWORD: Optional[str] = Field( + MILVUS_PASSWORD: str | None = Field( description="Password for authenticating with Milvus, if username/password authentication is enabled", default=None, ) @@ -40,7 +38,7 @@ class MilvusConfig(BaseSettings): default=True, ) - MILVUS_ANALYZER_PARAMS: Optional[str] = Field( + MILVUS_ANALYZER_PARAMS: str | None = Field( description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', default=None, ) diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py index 9b11a22732..8437328e76 100644 --- a/api/configs/middleware/vdb/oceanbase_config.py +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class OceanBaseVectorConfig(BaseSettings): Configuration settings for OceanBase Vector database """ - OCEANBASE_VECTOR_HOST: Optional[str] = Field( + OCEANBASE_VECTOR_HOST: str | None = Field( description="Hostname or IP address of the OceanBase Vector server (e.g. 
'localhost')", default=None, ) - OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field( + OCEANBASE_VECTOR_PORT: PositiveInt | None = Field( description="Port number on which the OceanBase Vector server is listening (default is 2881)", default=2881, ) - OCEANBASE_VECTOR_USER: Optional[str] = Field( + OCEANBASE_VECTOR_USER: str | None = Field( description="Username for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field( + OCEANBASE_VECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_DATABASE: Optional[str] = Field( + OCEANBASE_VECTOR_DATABASE: str | None = Field( description="Name of the OceanBase Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/opengauss_config.py b/api/configs/middleware/vdb/opengauss_config.py index 87ea292ab4..b57c1e59a9 100644 --- a/api/configs/middleware/vdb/opengauss_config.py +++ b/api/configs/middleware/vdb/opengauss_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class OpenGaussConfig(BaseSettings): Configuration settings for OpenGauss """ - OPENGAUSS_HOST: Optional[str] = Field( + OPENGAUSS_HOST: str | None = Field( description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class OpenGaussConfig(BaseSettings): default=6600, ) - OPENGAUSS_USER: Optional[str] = Field( + OPENGAUSS_USER: str | None = Field( description="Username for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_PASSWORD: Optional[str] = Field( + OPENGAUSS_PASSWORD: str | None = Field( description="Password for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_DATABASE: Optional[str] = Field( + OPENGAUSS_DATABASE: str | None = Field( description="Name of the OpenGauss database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 9fd9b60194..ba015a6eb9 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,5 +1,5 @@ -import enum -from typing import Literal, Optional +from enum import Enum +from typing import Literal from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings): Configuration settings for OpenSearch """ - class AuthMethod(enum.StrEnum): + class AuthMethod(Enum): """ Authentication method for OpenSearch """ @@ -18,7 +18,7 @@ class OpenSearchConfig(BaseSettings): BASIC = "basic" AWS_MANAGED_IAM = "aws_managed_iam" - OPENSEARCH_HOST: Optional[str] = Field( + OPENSEARCH_HOST: str | None = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, ) @@ -43,21 +43,21 @@ class OpenSearchConfig(BaseSettings): default=AuthMethod.BASIC, ) - OPENSEARCH_USER: Optional[str] = Field( + OPENSEARCH_USER: str | None = Field( description="Username for authenticating with OpenSearch", default=None, ) - OPENSEARCH_PASSWORD: Optional[str] = Field( + OPENSEARCH_PASSWORD: str | None = Field( description="Password for authenticating with OpenSearch", default=None, ) - OPENSEARCH_AWS_REGION: Optional[str] = Field( + OPENSEARCH_AWS_REGION: str | None = Field( description="AWS region for 
diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py
index ea39909ef4..dc179e8e4f 100644
--- a/api/configs/middleware/vdb/oracle_config.py
+++ b/api/configs/middleware/vdb/oracle_config.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import Field
 from pydantic_settings import BaseSettings
 
@@ -9,33 +7,33 @@ class OracleConfig(BaseSettings):
     Configuration settings for Oracle database
     """
 
-    ORACLE_USER: Optional[str] = Field(
+    ORACLE_USER: str | None = Field(
         description="Username for authenticating with the Oracle database",
         default=None,
     )
 
-    ORACLE_PASSWORD: Optional[str] = Field(
+    ORACLE_PASSWORD: str | None = Field(
         description="Password for authenticating with the Oracle database",
         default=None,
     )
 
-    ORACLE_DSN: Optional[str] = Field(
+    ORACLE_DSN: str | None = Field(
         description="Oracle database connection string. For traditional database, use format 'host:port/service_name'. "
         "For autonomous database, use the service name from tnsnames.ora in the wallet",
         default=None,
     )
 
-    ORACLE_CONFIG_DIR: Optional[str] = Field(
+    ORACLE_CONFIG_DIR: str | None = Field(
         description="Directory containing the tnsnames.ora configuration file. Only used in thin mode connection",
         default=None,
     )
 
-    ORACLE_WALLET_LOCATION: Optional[str] = Field(
+    ORACLE_WALLET_LOCATION: str | None = Field(
         description="Oracle wallet directory path containing the wallet files for secure connection",
         default=None,
     )
 
-    ORACLE_WALLET_PASSWORD: Optional[str] = Field(
+    ORACLE_WALLET_PASSWORD: str | None = Field(
         description="Password to decrypt the Oracle wallet, if it is encrypted",
         default=None,
     )
diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py
index 9f5f7284d7..62334636a5 100644
--- a/api/configs/middleware/vdb/pgvector_config.py
+++ b/api/configs/middleware/vdb/pgvector_config.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import Field, PositiveInt
 from pydantic_settings import BaseSettings
 
@@ -9,7 +7,7 @@ class PGVectorConfig(BaseSettings):
     Configuration settings for PGVector (PostgreSQL with vector extension)
     """
 
-    PGVECTOR_HOST: Optional[str] = Field(
+    PGVECTOR_HOST: str | None = Field(
         description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')",
         default=None,
     )
@@ -19,17 +17,17 @@ class PGVectorConfig(BaseSettings):
         default=5433,
     )
 
-    PGVECTOR_USER: Optional[str] = Field(
+    PGVECTOR_USER: str | None = Field(
         description="Username for authenticating with the PostgreSQL database",
         default=None,
     )
 
-    PGVECTOR_PASSWORD: Optional[str] = Field(
+    PGVECTOR_PASSWORD: str | None = Field(
         description="Password for authenticating with the PostgreSQL database",
         default=None,
     )
 
-    PGVECTOR_DATABASE: Optional[str] = Field(
+    PGVECTOR_DATABASE: str | None = Field(
         description="Name of the PostgreSQL database to connect to",
         default=None,
     )
diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py
index fa3bca5bb7..7bc144c4ab 100644
--- a/api/configs/middleware/vdb/pgvectors_config.py
+++ b/api/configs/middleware/vdb/pgvectors_config.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import Field, PositiveInt
 from
pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class PGVectoRSConfig(BaseSettings): Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL) """ - PGVECTO_RS_HOST: Optional[str] = Field( + PGVECTO_RS_HOST: str | None = Field( description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class PGVectoRSConfig(BaseSettings): default=5431, ) - PGVECTO_RS_USER: Optional[str] = Field( + PGVECTO_RS_USER: str | None = Field( description="Username for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_PASSWORD: Optional[str] = Field( + PGVECTO_RS_PASSWORD: str | None = Field( description="Password for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_DATABASE: Optional[str] = Field( + PGVECTO_RS_DATABASE: str | None = Field( description="Name of the PostgreSQL database with PGVecto.RS extension to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index 0a753eddec..b9e8e861da 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class QdrantConfig(BaseSettings): Configuration settings for Qdrant vector database """ - QDRANT_URL: Optional[str] = Field( + QDRANT_URL: str | None = Field( description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')", default=None, ) - QDRANT_API_KEY: Optional[str] = Field( + QDRANT_API_KEY: str | None = Field( description="API key for authenticating with the Qdrant server", default=None, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index 5ffbea7b19..0ed5357852 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class RelytConfig(BaseSettings): Configuration settings for Relyt database """ - RELYT_HOST: Optional[str] = Field( + RELYT_HOST: str | None = Field( description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')", default=None, ) @@ -19,17 +17,17 @@ class RelytConfig(BaseSettings): default=9200, ) - RELYT_USER: Optional[str] = Field( + RELYT_USER: str | None = Field( description="Username for authenticating with the Relyt database", default=None, ) - RELYT_PASSWORD: Optional[str] = Field( + RELYT_PASSWORD: str | None = Field( description="Password for authenticating with the Relyt database", default=None, ) - RELYT_DATABASE: Optional[str] = Field( + RELYT_DATABASE: str | None = Field( description="Name of the Relyt database to connect to (default is 'default')", default="default", ) diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py index 1aab01c6e1..2cec384b5d 100644 --- a/api/configs/middleware/vdb/tablestore_config.py +++ b/api/configs/middleware/vdb/tablestore_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class TableStoreConfig(BaseSettings): Configuration settings for TableStore. 
""" - TABLESTORE_ENDPOINT: Optional[str] = Field( + TABLESTORE_ENDPOINT: str | None = Field( description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')", default=None, ) - TABLESTORE_INSTANCE_NAME: Optional[str] = Field( + TABLESTORE_INSTANCE_NAME: str | None = Field( description="Instance name to access TableStore server (eg. 'instance-name')", default=None, ) - TABLESTORE_ACCESS_KEY_ID: Optional[str] = Field( + TABLESTORE_ACCESS_KEY_ID: str | None = Field( description="AccessKey id for the instance name", default=None, ) - TABLESTORE_ACCESS_KEY_SECRET: Optional[str] = Field( + TABLESTORE_ACCESS_KEY_SECRET: str | None = Field( description="AccessKey secret for the instance name", default=None, ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index a51823c3f3..3dc21ab89a 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class TencentVectorDBConfig(BaseSettings): Configuration settings for Tencent Vector Database """ - TENCENT_VECTOR_DB_URL: Optional[str] = Field( + TENCENT_VECTOR_DB_URL: str | None = Field( description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')", default=None, ) - TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( + TENCENT_VECTOR_DB_API_KEY: str | None = Field( description="API key for authenticating with the Tencent Vector Database service", default=None, ) @@ -24,12 +22,12 @@ class TencentVectorDBConfig(BaseSettings): default=30, ) - TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( + TENCENT_VECTOR_DB_USERNAME: str | None = Field( description="Username for authenticating with the Tencent Vector Database (if required)", default=None, ) - TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( + TENCENT_VECTOR_DB_PASSWORD: str | None = Field( description="Password for authenticating with the Tencent Vector Database (if required)", default=None, ) @@ -44,7 +42,7 @@ class TencentVectorDBConfig(BaseSettings): default=2, ) - TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( + TENCENT_VECTOR_DB_DATABASE: str | None = Field( description="Name of the specific Tencent Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py index d2625af264..9ca0955129 100644 --- a/api/configs/middleware/vdb/tidb_on_qdrant_config.py +++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class TidbOnQdrantConfig(BaseSettings): Tidb on Qdrant configs """ - TIDB_ON_QDRANT_URL: Optional[str] = Field( + TIDB_ON_QDRANT_URL: str | None = Field( description="Tidb on Qdrant url", default=None, ) - TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( + TIDB_ON_QDRANT_API_KEY: str | None = Field( description="Tidb on Qdrant api key", default=None, ) @@ -34,37 +32,37 @@ class TidbOnQdrantConfig(BaseSettings): default=6334, ) - TIDB_PUBLIC_KEY: Optional[str] = Field( + TIDB_PUBLIC_KEY: str | None = Field( description="Tidb account public key", default=None, ) - TIDB_PRIVATE_KEY: Optional[str] = Field( + TIDB_PRIVATE_KEY: str | None = Field( 
description="Tidb account private key", default=None, ) - TIDB_API_URL: Optional[str] = Field( + TIDB_API_URL: str | None = Field( description="Tidb API url", default=None, ) - TIDB_IAM_API_URL: Optional[str] = Field( + TIDB_IAM_API_URL: str | None = Field( description="Tidb IAM API url", default=None, ) - TIDB_REGION: Optional[str] = Field( + TIDB_REGION: str | None = Field( description="Tidb serverless region", default="regions/aws-us-east-1", ) - TIDB_PROJECT_ID: Optional[str] = Field( + TIDB_PROJECT_ID: str | None = Field( description="Tidb project id", default=None, ) - TIDB_SPEND_LIMIT: Optional[int] = Field( + TIDB_SPEND_LIMIT: int | None = Field( description="Tidb spend limit", default=100, ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index bc68be69d8..0ebf226bea 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class TiDBVectorConfig(BaseSettings): Configuration settings for TiDB Vector database """ - TIDB_VECTOR_HOST: Optional[str] = Field( + TIDB_VECTOR_HOST: str | None = Field( description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')", default=None, ) - TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( + TIDB_VECTOR_PORT: PositiveInt | None = Field( description="Port number on which the TiDB Vector server is listening (default is 4000)", default=4000, ) - TIDB_VECTOR_USER: Optional[str] = Field( + TIDB_VECTOR_USER: str | None = Field( description="Username for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_PASSWORD: Optional[str] = Field( + TIDB_VECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_DATABASE: Optional[str] = Field( + TIDB_VECTOR_DATABASE: str | None = Field( description="Name of the TiDB Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/upstash_config.py b/api/configs/middleware/vdb/upstash_config.py index 412c56374a..01a0442f70 100644 --- a/api/configs/middleware/vdb/upstash_config.py +++ b/api/configs/middleware/vdb/upstash_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class UpstashConfig(BaseSettings): Configuration settings for Upstash vector database """ - UPSTASH_VECTOR_URL: Optional[str] = Field( + UPSTASH_VECTOR_URL: str | None = Field( description="URL of the upstash server (e.g., 'https://vector.upstash.io')", default=None, ) - UPSTASH_VECTOR_TOKEN: Optional[str] = Field( + UPSTASH_VECTOR_TOKEN: str | None = Field( description="Token for authenticating with the upstash server", default=None, ) diff --git a/api/configs/middleware/vdb/vastbase_vector_config.py b/api/configs/middleware/vdb/vastbase_vector_config.py index 816d6df90a..ced4cf154c 100644 --- a/api/configs/middleware/vdb/vastbase_vector_config.py +++ b/api/configs/middleware/vdb/vastbase_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class VastbaseVectorConfig(BaseSettings): Configuration settings for Vector (Vastbase with vector extension) """ - VASTBASE_HOST: Optional[str] = Field( + VASTBASE_HOST: str | 
None = Field(
         description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
         default=None,
     )
@@ -19,17 +17,17 @@ class VastbaseVectorConfig(BaseSettings):
         default=5432,
     )
 
-    VASTBASE_USER: Optional[str] = Field(
+    VASTBASE_USER: str | None = Field(
         description="Username for authenticating with the Vastbase database",
         default=None,
     )
 
-    VASTBASE_PASSWORD: Optional[str] = Field(
+    VASTBASE_PASSWORD: str | None = Field(
         description="Password for authenticating with the Vastbase database",
         default=None,
     )
 
-    VASTBASE_DATABASE: Optional[str] = Field(
+    VASTBASE_DATABASE: str | None = Field(
         description="Name of the Vastbase database to connect to",
         default=None,
     )
diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py
index aba49ff670..3d5306bb61 100644
--- a/api/configs/middleware/vdb/vikingdb_config.py
+++ b/api/configs/middleware/vdb/vikingdb_config.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import Field
 from pydantic_settings import BaseSettings
 
@@ -11,14 +9,14 @@ class VikingDBConfig(BaseSettings):
     https://www.volcengine.com/docs/6291/65568
     """
 
-    VIKINGDB_ACCESS_KEY: Optional[str] = Field(
+    VIKINGDB_ACCESS_KEY: str | None = Field(
         description="The Access Key provided by Volcengine VikingDB for API authentication."
         "Refer to the following documentation for details on obtaining credentials:"
         "https://www.volcengine.com/docs/6291/65568",
         default=None,
     )
 
-    VIKINGDB_SECRET_KEY: Optional[str] = Field(
+    VIKINGDB_SECRET_KEY: str | None = Field(
         description="The Secret Key provided by Volcengine VikingDB for API authentication.",
         default=None,
     )
diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py
index 25000e8bde..6a79412ab8 100644
--- a/api/configs/middleware/vdb/weaviate_config.py
+++ b/api/configs/middleware/vdb/weaviate_config.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import Field, PositiveInt
 from pydantic_settings import BaseSettings
 
@@ -9,12 +7,12 @@ class WeaviateConfig(BaseSettings):
     Configuration settings for Weaviate vector database
     """
 
-    WEAVIATE_ENDPOINT: Optional[str] = Field(
+    WEAVIATE_ENDPOINT: str | None = Field(
         description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')",
         default=None,
     )
 
-    WEAVIATE_API_KEY: Optional[str] = Field(
+    WEAVIATE_API_KEY: str | None = Field(
         description="API key for authenticating with the Weaviate server",
         default=None,
     )
diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py
index f02f7dc9ff..55c14ead56 100644
--- a/api/configs/remote_settings_sources/apollo/__init__.py
+++ b/api/configs/remote_settings_sources/apollo/__init__.py
@@ -1,5 +1,5 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import Field
 from pydantic.fields import FieldInfo
@@ -15,22 +15,22 @@ class ApolloSettingsSourceInfo(BaseSettings):
     Packaging build information
     """
 
-    APOLLO_APP_ID: Optional[str] = Field(
+    APOLLO_APP_ID: str | None = Field(
         description="apollo app_id",
         default=None,
     )
 
-    APOLLO_CLUSTER: Optional[str] = Field(
+    APOLLO_CLUSTER: str | None = Field(
         description="apollo cluster",
         default=None,
     )
 
-    APOLLO_CONFIG_URL: Optional[str] = Field(
+    APOLLO_CONFIG_URL: str | None = Field(
         description="apollo config url",
         default=None,
     )
 
-    APOLLO_NAMESPACE: Optional[str] = Field(
+    APOLLO_NAMESPACE: str | None = Field(
         description="apollo namespace",
         default=None,
     )
diff --git a/api/constants/__init__.py b/api/constants/__init__.py
index c98f4d55c8..fe8f4f8785 100644
--- a/api/constants/__init__.py
+++ b/api/constants/__init__.py
@@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
 AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
 
+_doc_extensions: list[str]
 if dify_config.ETL_TYPE == "Unstructured":
-    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
-    DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
+    _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
+    _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
     if dify_config.UNSTRUCTURED_API_URL:
-        DOCUMENT_EXTENSIONS.append("ppt")
-    DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
+        _doc_extensions.append("ppt")
 else:
-    DOCUMENT_EXTENSIONS = [
+    _doc_extensions = [
         "txt",
         "markdown",
         "md",
@@ -38,4 +38,4 @@ else:
         "vtt",
         "properties",
     ]
-    DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
+DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
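
The constants refactor above builds one lowercase list and then derives DOCUMENT_EXTENSIONS as lowercase plus uppercase variants in a single final expression, instead of extending the list in place inside only one branch. A quick standalone check of that pattern, with abbreviated values:

# Standalone check of the list-construction pattern (abbreviated values).
_doc_extensions = ["txt", "md", "pdf"]
_doc_extensions.append("ppt")  # conditional in the real code

DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
print(DOCUMENT_EXTENSIONS)
# ['txt', 'md', 'pdf', 'ppt', 'TXT', 'MD', 'PDF', 'PPT']
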
diff --git a/api/constants/model_template.py b/api/constants/model_template.py
index c26d8c0186..cacf6b6874 100644
--- a/api/constants/model_template.py
+++ b/api/constants/model_template.py
@@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
     # workflow default mode
     AppMode.WORKFLOW: {
         "app": {
-            "mode": AppMode.WORKFLOW.value,
+            "mode": AppMode.WORKFLOW,
             "enable_site": True,
             "enable_api": True,
         }
@@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
     # completion default mode
     AppMode.COMPLETION: {
         "app": {
-            "mode": AppMode.COMPLETION.value,
+            "mode": AppMode.COMPLETION,
             "enable_site": True,
             "enable_api": True,
         },
@@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
     # chat default mode
     AppMode.CHAT: {
         "app": {
-            "mode": AppMode.CHAT.value,
+            "mode": AppMode.CHAT,
             "enable_site": True,
             "enable_api": True,
         },
@@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
     # advanced-chat default mode
     AppMode.ADVANCED_CHAT: {
         "app": {
-            "mode": AppMode.ADVANCED_CHAT.value,
+            "mode": AppMode.ADVANCED_CHAT,
             "enable_site": True,
             "enable_api": True,
         },
@@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
     # agent-chat default mode
     AppMode.AGENT_CHAT: {
         "app": {
-            "mode": AppMode.AGENT_CHAT.value,
+            "mode": AppMode.AGENT_CHAT,
             "enable_site": True,
             "enable_api": True,
         },
diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py
index 8be769e798..3c5d76d14a 100644
--- a/api/contexts/__init__.py
+++ b/api/contexts/__init__.py
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
     from core.model_runtime.entities.model_entities import AIModelEntity
     from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
     from core.tools.plugin_tool.provider import PluginToolProviderController
-    from core.workflow.entities.variable_pool import VariablePool
 
 """
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 7991fe633a..4e54fa9220 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -1,4 +1,5 @@
 from flask import Blueprint
+from flask_restx import Namespace
 
 from libs.external_api import ExternalApi
 
@@ -26,7 +27,16 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from
.remote_files import RemoteFileInfoApi, RemoteFileUploadApi bp = Blueprint("console", __name__, url_prefix="/console/api") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Console API", + description="Console management APIs for app configuration, monitoring, and administration", +) + +# Create namespace +console_ns = Namespace("console", description="Console management API operations", path="/") # File api.add_resource(FileApi, "/files/upload") @@ -43,7 +53,16 @@ api.add_resource(AppImportConfirmApi, "/apps/imports//confirm" api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") # Import other controllers -from . import admin, apikey, extension, feature, ping, setup, spec, version +from . import ( + admin, + apikey, + extension, + feature, + init_validate, + ping, + setup, + version, +) # Import app controllers from .app import ( @@ -70,7 +89,16 @@ from .app import ( ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server +from .auth import ( + activate, + data_source_bearer_auth, + data_source_oauth, + email_register, + forgot_password, + login, + oauth, + oauth_server, +) # Import billing controllers from .billing import billing, compliance @@ -104,6 +132,23 @@ from .explore import ( saved_message, ) +# Import tag controllers +from .tag import tags + +# Import workspace controllers +from .workspace import ( + account, + agent_providers, + endpoint, + load_balancing_config, + members, + model_providers, + models, + plugin, + tool_providers, + workspace, +) + # Explore Audio api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") @@ -175,19 +220,71 @@ api.add_resource( InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) -# Import tag controllers -from .tag import tags +api.add_namespace(console_ns) -# Import workspace controllers -from .workspace import ( - account, - agent_providers, - endpoint, - load_balancing_config, - members, - model_providers, - models, - plugin, - tool_providers, - workspace, -) +__all__ = [ + "account", + "activate", + "admin", + "advanced_prompt_template", + "agent", + "agent_providers", + "annotation", + "api", + "apikey", + "app", + "audio", + "billing", + "bp", + "completion", + "compliance", + "console_ns", + "conversation", + "conversation_variables", + "data_source", + "data_source_bearer_auth", + "data_source_oauth", + "datasets", + "datasets_document", + "datasets_segments", + "email_register", + "endpoint", + "extension", + "external", + "feature", + "forgot_password", + "generator", + "hit_testing", + "init_validate", + "installed_app", + "load_balancing_config", + "login", + "mcp_server", + "members", + "message", + "metadata", + "model_config", + "model_providers", + "models", + "oauth", + "oauth_server", + "ops_trace", + "parameter", + "ping", + "plugin", + "recommended_app", + "saved_message", + "setup", + "site", + "statistic", + "tags", + "tool_providers", + "version", + "website", + "workflow", + "workflow_app_log", + "workflow_draft_variable", + "workflow_run", + "workflow_statistic", + "workspace", +] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 1306efacf4..93f242ad28 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,7 +3,7 @@ from functools import wraps from 
typing import ParamSpec, TypeVar from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -12,7 +12,7 @@ P = ParamSpec("P") R = TypeVar("R") from configs import dify_config from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db from models.model import App, InstalledApp, RecommendedApp @@ -45,7 +45,28 @@ def admin_required(view: Callable[P, R]): return decorated +@console_ns.route("/admin/insert-explore-apps") class InsertExploreAppListApi(Resource): + @api.doc("insert_explore_app") + @api.doc(description="Insert or update an app in the explore list") + @api.expect( + api.model( + "InsertExploreAppRequest", + { + "app_id": fields.String(required=True, description="Application ID"), + "desc": fields.String(description="App description"), + "copyright": fields.String(description="Copyright information"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "language": fields.String(required=True, description="Language code"), + "category": fields.String(required=True, description="App category"), + "position": fields.Integer(required=True, description="Display position"), + }, + ) + ) + @api.response(200, "App updated successfully") + @api.response(201, "App inserted successfully") + @api.response(404, "App not found") @only_edition_cloud @admin_required def post(self): @@ -115,7 +136,12 @@ class InsertExploreAppListApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/admin/insert-explore-apps/") class InsertExploreAppApi(Resource): + @api.doc("delete_explore_app") + @api.doc(description="Remove an app from the explore list") + @api.doc(params={"app_id": "Application ID to remove"}) + @api.response(204, "App removed successfully") @only_edition_cloud @admin_required def delete(self, app_id): @@ -152,7 +178,3 @@ class InsertExploreAppApi(Resource): db.session.commit() return {"result": "success"}, 204 - - -api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") -api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index cfd5f73ade..fec527e4cb 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,8 +1,7 @@ -from typing import Any, Optional - import flask_restx from flask_login import current_user from flask_restx import Resource, fields, marshal_with +from flask_restx._http import HTTPStatus from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,7 +12,7 @@ from libs.login import login_required from models.dataset import Dataset from models.model import ApiToken, App -from . import api +from . 
import api, console_ns from .wraps import account_initialization_required, setup_required api_key_fields = { @@ -40,7 +39,7 @@ def _get_resource(resource_id, tenant_id, resource_model): ).scalar_one_or_none() if resource is None: - flask_restx.abort(404, message=f"{resource_model.__name__} not found.") + flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") return resource @@ -49,7 +48,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -59,11 +58,11 @@ class BaseApiKeyListResource(Resource): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - keys = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .all() - ) + keys = db.session.scalars( + select(ApiToken).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ).all() return {"items": keys} @marshal_with(api_key_fields) @@ -82,7 +81,7 @@ class BaseApiKeyListResource(Resource): if current_key_count >= self.max_keys: flask_restx.abort( - 400, + HTTPStatus.BAD_REQUEST, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", custom="max_keys_exceeded", ) @@ -102,7 +101,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: type | None = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): @@ -126,7 +125,7 @@ class BaseApiKeyResource(Resource): ) if key is None: - flask_restx.abort(404, message="API key not found") + flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() @@ -134,7 +133,25 @@ class BaseApiKeyResource(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//api-keys") class AppApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_app_api_keys") + @api.doc(description="Get all API keys for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for an app""" + return super().get(resource_id) + + @api.doc("create_app_api_key") + @api.doc(description="Create a new API key for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for an app""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -146,7 +163,16 @@ class AppApiKeyListResource(BaseApiKeyListResource): token_prefix = "app-" +@console_ns.route("/apps//api-keys/") class AppApiKeyResource(BaseApiKeyResource): + @api.doc("delete_app_api_key") + @api.doc(description="Delete an API key for an app") + @api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) + 
@api.response(204, "API key deleted successfully")
+    def delete(self, resource_id, api_key_id):
+        """Delete an API key for an app"""
+        return super().delete(resource_id, api_key_id)
+
     def after_request(self, resp):
         resp.headers["Access-Control-Allow-Origin"] = "*"
         resp.headers["Access-Control-Allow-Credentials"] = "true"
@@ -157,7 +183,25 @@ class AppApiKeyResource(BaseApiKeyResource):
     resource_id_field = "app_id"
 
 
+@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
 class DatasetApiKeyListResource(BaseApiKeyListResource):
+    @api.doc("get_dataset_api_keys")
+    @api.doc(description="Get all API keys for a dataset")
+    @api.doc(params={"resource_id": "Dataset ID"})
+    @api.response(200, "Success", api_key_list)
+    def get(self, resource_id):
+        """Get all API keys for a dataset"""
+        return super().get(resource_id)
+
+    @api.doc("create_dataset_api_key")
+    @api.doc(description="Create a new API key for a dataset")
+    @api.doc(params={"resource_id": "Dataset ID"})
+    @api.response(201, "API key created successfully", api_key_fields)
+    @api.response(400, "Maximum keys exceeded")
+    def post(self, resource_id):
+        """Create a new API key for a dataset"""
+        return super().post(resource_id)
+
     def after_request(self, resp):
         resp.headers["Access-Control-Allow-Origin"] = "*"
         resp.headers["Access-Control-Allow-Credentials"] = "true"
@@ -169,7 +213,16 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
     token_prefix = "ds-"
 
 
+@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
 class DatasetApiKeyResource(BaseApiKeyResource):
+    @api.doc("delete_dataset_api_key")
+    @api.doc(description="Delete an API key for a dataset")
+    @api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
+    @api.response(204, "API key deleted successfully")
+    def delete(self, resource_id, api_key_id):
+        """Delete an API key for a dataset"""
+        return super().delete(resource_id, api_key_id)
+
     def after_request(self, resp):
         resp.headers["Access-Control-Allow-Origin"] = "*"
         resp.headers["Access-Control-Allow-Credentials"] = "true"
@@ -178,9 +231,3 @@ class DatasetApiKeyResource(BaseApiKeyResource):
     resource_type = "dataset"
     resource_model = Dataset
     resource_id_field = "dataset_id"
-
-
-api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
-api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
-api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
-api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
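
The controller changes that follow all repeat one migration: resources register through a flask-restx Namespace route decorator instead of trailing api.add_resource calls, and gain @api.doc/@api.response metadata for the generated OpenAPI page. A minimal self-contained sketch of that registration pattern, with illustrative names only (not the actual Dify resources):

# Sketch of the Namespace-based registration pattern, assuming flask-restx.
# ExampleResource and the route path are hypothetical.
from flask import Flask
from flask_restx import Api, Namespace, Resource

app = Flask(__name__)
api = Api(app, version="1.0", title="Console API")
console_ns = Namespace("console", description="Console operations", path="/")


@console_ns.route("/apps/<string:app_id>/example")
class ExampleResource(Resource):
    @console_ns.doc(description="Shown on the generated Swagger page")
    @console_ns.response(200, "Success")
    def get(self, app_id):
        return {"app_id": app_id}


api.add_namespace(console_ns)  # replaces a trailing api.add_resource(...) call
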
mode") + .add_argument("has_context", type=str, default="true", location="args", help="Whether has context") + .add_argument("model_name", type=str, required=True, location="args", help="Model name") + ) + @api.response( + 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) + ) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -19,6 +33,3 @@ class AdvancedPromptTemplateList(Resource): args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) - - -api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index a964154207..c063f336c7 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,6 +1,6 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value @@ -9,7 +9,18 @@ from models.model import AppMode from services.agent_service import AgentService +@console_ns.route("/apps//agent/logs") class AgentLogApi(Resource): + @api.doc("get_agent_logs") + @api.doc(description="Get agent execution logs for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("message_id", type=str, required=True, location="args", help="Message UUID") + .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID") + ) + @api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -23,6 +34,3 @@ class AgentLogApi(Resource): args = parser.parse_args() return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) - - -api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 37d23ccd9f..d0ee11fe75 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -2,11 +2,11 @@ from typing import Literal from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.common.errors import NoFileUploadedError, TooManyFilesError -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -21,7 +21,23 @@ from libs.login import login_required from services.annotation_service import AppAnnotationService +@console_ns.route("/apps//annotation-reply/") class AnnotationReplyActionApi(Resource): + @api.doc("annotation_reply_action") + @api.doc(description="Enable or disable annotation reply for an app") + @api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) + @api.expect( + api.model( + "AnnotationReplyActionRequest", + { + "score_threshold": 
fields.Float(required=True, description="Score threshold for annotation matching"), + "embedding_provider_name": fields.String(required=True, description="Embedding provider name"), + "embedding_model_name": fields.String(required=True, description="Embedding model name"), + }, + ) + ) + @api.response(200, "Action completed successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -43,7 +59,13 @@ class AnnotationReplyActionApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-setting") class AppAnnotationSettingDetailApi(Resource): + @api.doc("get_annotation_setting") + @api.doc(description="Get annotation settings for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Annotation settings retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -56,7 +78,23 @@ class AppAnnotationSettingDetailApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-settings/") class AppAnnotationSettingUpdateApi(Resource): + @api.doc("update_annotation_setting") + @api.doc(description="Update annotation settings for an app") + @api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) + @api.expect( + api.model( + "AnnotationSettingUpdateRequest", + { + "score_threshold": fields.Float(required=True, description="Score threshold"), + "embedding_provider_name": fields.String(required=True, description="Embedding provider"), + "embedding_model_name": fields.String(required=True, description="Embedding model"), + }, + ) + ) + @api.response(200, "Settings updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -75,7 +113,13 @@ class AppAnnotationSettingUpdateApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-reply//status/") class AnnotationReplyActionStatusApi(Resource): + @api.doc("get_annotation_reply_action_status") + @api.doc(description="Get status of annotation reply action job") + @api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"}) + @api.response(200, "Job status retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -99,7 +143,19 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +@console_ns.route("/apps//annotations") class AnnotationApi(Resource): + @api.doc("list_annotations") + @api.doc(description="Get annotations for an app with pagination") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size") + .add_argument("keyword", type=str, location="args", default="", help="Search keyword") + ) + @api.response(200, "Annotations retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -122,6 +178,21 @@ class AnnotationApi(Resource): } return response, 200 + @api.doc("create_annotation") + @api.doc(description="Create a new annotation for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "CreateAnnotationRequest", + { + "question": 
fields.String(required=True, description="Question text"), + "answer": fields.String(required=True, description="Answer text"), + "annotation_reply": fields.Raw(description="Annotation reply data"), + }, + ) + ) + @api.response(201, "Annotation created successfully", annotation_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -168,7 +239,13 @@ class AnnotationApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): + @api.doc("export_annotations") + @api.doc(description="Export all annotations for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields))) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -182,7 +259,14 @@ class AnnotationExportApi(Resource): return response, 200 +@console_ns.route("/apps//annotations/") class AnnotationUpdateDeleteApi(Resource): + @api.doc("update_delete_annotation") + @api.doc(description="Update or delete an annotation") + @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) + @api.response(200, "Annotation updated successfully", annotation_fields) + @api.response(204, "Annotation deleted successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -214,7 +298,14 @@ class AnnotationUpdateDeleteApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): + @api.doc("batch_import_annotations") + @api.doc(description="Batch import annotations from CSV file") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Batch import started successfully") + @api.response(403, "Insufficient permissions") + @api.response(400, "No file uploaded or too many files") @setup_required @login_required @account_initialization_required @@ -239,7 +330,13 @@ class AnnotationBatchImportApi(Resource): return AppAnnotationService.batch_import_app_annotations(app_id, file) +@console_ns.route("/apps//annotations/batch-import-status/") class AnnotationBatchImportStatusApi(Resource): + @api.doc("get_batch_import_status") + @api.doc(description="Get status of batch import job") + @api.doc(params={"app_id": "Application ID", "job_id": "Job ID"}) + @api.response(200, "Job status retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -262,7 +359,20 @@ class AnnotationBatchImportStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +@console_ns.route("/apps//annotations//hit-histories") class AnnotationHitHistoryListApi(Resource): + @api.doc("list_annotation_hit_histories") + @api.doc(description="Get hit histories for an annotation") + @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size") + ) + @api.response( + 200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields)) + ) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ 
-285,17 +395,3 @@ class AnnotationHitHistoryListApi(Resource): "page": page, } return response - - -api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") -api.add_resource( - AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" -) -api.add_resource(AnnotationApi, "/apps//annotations") -api.add_resource(AnnotationExportApi, "/apps//annotations/export") -api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") -api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") -api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") -api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") -api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") -api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 10753d2f95..2d2e4b448a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,12 +2,12 @@ import uuid from typing import cast from flask_login import current_user -from flask_restx import Resource, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, @@ -34,7 +34,27 @@ def _validate_description_length(description): return description +@console_ns.route("/apps") class AppListApi(Resource): + @api.doc("list_apps") + @api.doc(description="Get list of applications with pagination and filtering") + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1) + .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20) + .add_argument( + "mode", + type=str, + location="args", + choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"], + default="all", + help="App mode filter", + ) + .add_argument("name", type=str, location="args", help="Filter by app name") + .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs") + .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator") + ) + @api.response(200, "Success", app_pagination_fields) @setup_required @login_required @account_initialization_required @@ -91,6 +111,24 @@ class AppListApi(Resource): return marshal(app_pagination, app_pagination_fields), 200 + @api.doc("create_app") + @api.doc(description="Create a new application") + @api.expect( + api.model( + "CreateAppRequest", + { + "name": fields.String(required=True, description="App name"), + "description": fields.String(description="App description (max 400 chars)"), + "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(201, "App created successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(400, "Invalid request parameters") @setup_required @login_required 
@account_initialization_required @@ -115,12 +153,21 @@ class AppListApi(Resource): raise BadRequest("mode is required") app_service = AppService() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + if current_user.current_tenant_id is None: + raise ValueError("current_user.current_tenant_id cannot be None") app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 +@console_ns.route("/apps/") class AppApi(Resource): + @api.doc("get_app_detail") + @api.doc(description="Get application details") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Success", app_detail_fields_with_site) @setup_required @login_required @account_initialization_required @@ -139,6 +186,26 @@ class AppApi(Resource): return app_model + @api.doc("update_app") + @api.doc(description="Update application details") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "UpdateAppRequest", + { + "name": fields.String(required=True, description="App name"), + "description": fields.String(description="App description (max 400 chars)"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), + "max_active_requests": fields.Integer(description="Maximum active requests"), + }, + ) + ) + @api.response(200, "App updated successfully", app_detail_fields_with_site) + @api.response(403, "Insufficient permissions") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -161,14 +228,31 @@ class AppApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app(app_model, args) + # Construct ArgsDict from parsed arguments + from services.app_service import AppService as AppServiceType + + args_dict: AppServiceType.ArgsDict = { + "name": args["name"], + "description": args.get("description", ""), + "icon_type": args.get("icon_type", ""), + "icon": args.get("icon", ""), + "icon_background": args.get("icon_background", ""), + "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), + "max_active_requests": args.get("max_active_requests", 0), + } + app_model = app_service.update_app(app_model, args_dict) return app_model + @api.doc("delete_app") + @api.doc(description="Delete application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(204, "App deleted successfully") + @api.response(403, "Insufficient permissions") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def delete(self, app_model): """Delete app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -181,7 +265,25 @@ class AppApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//copy") class AppCopyApi(Resource): + @api.doc("copy_app") + @api.doc(description="Create a copy of an existing application") + @api.doc(params={"app_id": "Application ID to copy"}) + @api.expect( + api.model( + "CopyAppRequest", + { + "name": fields.String(description="Name for the copied app"), + "description": fields.String(description="Description for the copied app"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background 
color"), + }, + ) + ) + @api.response(201, "App copied successfully", app_detail_fields_with_site) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -223,11 +325,26 @@ class AppCopyApi(Resource): return app, 201 +@console_ns.route("/apps//export") class AppExportApi(Resource): + @api.doc("export_app") + @api.doc(description="Export application configuration as DSL") + @api.doc(params={"app_id": "Application ID to export"}) + @api.expect( + api.parser() + .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export") + .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export") + ) + @api.response( + 200, + "App exported successfully", + api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), + ) + @api.response(403, "Insufficient permissions") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): """Export app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -247,7 +364,13 @@ class AppExportApi(Resource): } +@console_ns.route("/apps//name") class AppNameApi(Resource): + @api.doc("check_app_name") + @api.doc(description="Check if app name is available") + @api.doc(params={"app_id": "Application ID"}) + @api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check")) + @api.response(200, "Name availability checked") @setup_required @login_required @account_initialization_required @@ -263,12 +386,28 @@ class AppNameApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get("name")) + app_model = app_service.update_app_name(app_model, args["name"]) return app_model +@console_ns.route("/apps//icon") class AppIconApi(Resource): + @api.doc("update_app_icon") + @api.doc(description="Update application icon") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppIconRequest", + { + "icon": fields.String(required=True, description="Icon data"), + "icon_type": fields.String(description="Icon type"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(200, "Icon updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -285,12 +424,23 @@ class AppIconApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) + app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") return app_model +@console_ns.route("/apps//site-enable") class AppSiteStatus(Resource): + @api.doc("update_app_site_status") + @api.doc(description="Enable or disable app site") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")} + ) + ) + @api.response(200, "Site status updated successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -306,12 +456,23 @@ class AppSiteStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, 
args.get("enable_site")) + app_model = app_service.update_app_site_status(app_model, args["enable_site"]) return app_model +@console_ns.route("/apps//api-enable") class AppApiStatus(Resource): + @api.doc("update_app_api_status") + @api.doc(description="Enable or disable app API") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")} + ) + ) + @api.response(200, "API status updated successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -327,12 +488,17 @@ class AppApiStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) + app_model = app_service.update_app_api_status(app_model, args["enable_api"]) return app_model +@console_ns.route("/apps//trace") class AppTraceApi(Resource): + @api.doc("get_app_trace") + @api.doc(description="Get app tracing configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Trace configuration retrieved successfully") @setup_required @login_required @account_initialization_required @@ -342,6 +508,20 @@ class AppTraceApi(Resource): return app_trace_config + @api.doc("update_app_trace") + @api.doc(description="Update app tracing configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppTraceRequest", + { + "enabled": fields.Boolean(required=True, description="Enable or disable tracing"), + "tracing_provider": fields.String(required=True, description="Tracing provider"), + }, + ) + ) + @api.response(200, "Trace configuration updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -361,14 +541,3 @@ class AppTraceApi(Resource): ) return {"result": "success"} - - -api.add_resource(AppListApi, "/apps") -api.add_resource(AppApi, "/apps/") -api.add_resource(AppCopyApi, "/apps//copy") -api.add_resource(AppExportApi, "/apps//export") -api.add_resource(AppNameApi, "/apps//name") -api.add_resource(AppIconApi, "/apps//icon") -api.add_resource(AppSiteStatus, "/apps//site-enable") -api.add_resource(AppApiStatus, "/apps//api-enable") -api.add_resource(AppTraceApi, "/apps//trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index aaf5c3dfaa..7d659dae0d 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -34,7 +34,18 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@console_ns.route("/apps//audio-to-text") class ChatMessageAudioApi(Resource): + @api.doc("chat_message_audio_transcript") + @api.doc(description="Transcript audio to text for chat messages") + @api.doc(params={"app_id": "App ID"}) + @api.response( + 200, + "Audio transcription successful", + api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + ) + @api.response(400, "Bad request - No audio 
uploaded or unsupported type") + @api.response(413, "Audio file too large") @setup_required @login_required @account_initialization_required @@ -76,11 +87,28 @@ class ChatMessageAudioApi(Resource): raise InternalServerError() +@console_ns.route("/apps//text-to-audio") class ChatMessageTextApi(Resource): + @api.doc("chat_message_text_to_speech") + @api.doc(description="Convert text to speech for chat messages") + @api.doc(params={"app_id": "App ID"}) + @api.expect( + api.model( + "TextToSpeechRequest", + { + "message_id": fields.String(description="Message ID"), + "text": fields.String(required=True, description="Text to convert to speech"), + "voice": fields.String(description="Voice to use for TTS"), + "streaming": fields.Boolean(description="Whether to stream the audio"), + }, + ) + ) + @api.response(200, "Text to speech conversion successful") + @api.response(400, "Bad request - Invalid parameters") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model: App): try: parser = reqparse.RequestParser() @@ -124,11 +152,18 @@ class ChatMessageTextApi(Resource): raise InternalServerError() +@console_ns.route("/apps//text-to-audio/voices") class TextModesApi(Resource): + @api.doc("get_text_to_speech_voices") + @api.doc(description="Get available TTS voices for a specific language") + @api.doc(params={"app_id": "App ID"}) + @api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code")) + @api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))) + @api.response(400, "Invalid language parameter") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): try: parser = reqparse.RequestParser() @@ -164,8 +199,3 @@ class TextModesApi(Resource): except Exception as e: logger.exception("Failed to handle get request to TextModesApi") raise InternalServerError() - - -api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") -api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") -api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 701ebb0b4a..2f7b90e7fb 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,12 +1,11 @@ import logging -import flask_login from flask import request -from flask_restx import Resource, reqparse -from werkzeug.exceptions import InternalServerError, NotFound +from flask_restx import Resource, fields, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from libs.login import login_required +from libs.login import current_user, login_required +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -38,7 +38,27 @@ logger = logging.getLogger(__name__) # define completion message api for user +@console_ns.route("/apps//completion-messages") 
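+# Registration sketch (assumes flask_restx's Namespace.route API; not part of
+# this patch): decorating the Resource class below is equivalent to the removed
+#     api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
+# call, with the URL converter (e.g. <uuid:app_id>) feeding app_id to the view.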
class CompletionMessageApi(Resource): + @api.doc("create_completion_message") + @api.doc(description="Generate completion message for debugging") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "CompletionMessageRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "query": fields.String(description="Query text", default=""), + "files": fields.List(fields.Raw(), description="Uploaded files"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), + "retriever_from": fields.String(default="dev", description="Retriever source"), + }, + ) + ) + @api.response(200, "Completion generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @@ -56,11 +76,11 @@ class CompletionMessageApi(Resource): streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -86,25 +106,58 @@ class CompletionMessageApi(Resource): raise InternalServerError() +@console_ns.route("/apps//completion-messages//stop") class CompletionMessageStopApi(Resource): + @api.doc("stop_completion_message") + @api.doc(description="Stop a running completion message generation") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) + @api.response(200, "Task stopped successfully") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 +@console_ns.route("/apps//chat-messages") class ChatMessageApi(Resource): + @api.doc("create_chat_message") + @api.doc(description="Generate chat message for debugging") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "ChatMessageRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "query": fields.String(required=True, description="User query"), + "files": fields.List(fields.Raw(), description="Uploaded files"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "conversation_id": fields.String(description="Conversation ID"), + "parent_message_id": fields.String(description="Parent message ID"), + "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), + "retriever_from": fields.String(default="dev", description="Retriever source"), + }, + ) + ) + @api.response(200, "Chat message generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(404, "App or conversation not found") @setup_required @login_required @account_initialization_required 
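+    # Guard ordering (inferred from decorator stacking, not stated in the patch):
+    # the wrappers run top-down per request, so setup_required -> login_required ->
+    # account_initialization_required all pass before get_app_model resolves the app.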
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): + if not isinstance(current_user, Account): + raise Forbidden() + + if not current_user.has_edit_permission: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") @@ -123,11 +176,11 @@ class ChatMessageApi(Resource): if external_trace_id: args["external_trace_id"] = external_trace_id - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -155,20 +208,19 @@ class ChatMessageApi(Resource): raise InternalServerError() +@console_ns.route("/apps//chat-messages//stop") class ChatMessageStopApi(Resource): + @api.doc("stop_chat_message") + @api.doc(description="Stop a running chat message generation") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) + @api.response(200, "Task stopped successfully") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionMessageApi, "/apps//completion-messages") -api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") -api.add_resource(ChatMessageApi, "/apps//chat-messages") -api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index bc825effad..c0cbf6613e 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -8,7 +8,7 @@ from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom @@ -22,13 +22,35 @@ from fields.conversation_fields import ( from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString from libs.login import login_required -from models import Conversation, EndUser, Message, MessageAnnotation +from models import Account, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError +@console_ns.route("/apps//completion-conversations") class CompletionConversationApi(Resource): + @api.doc("list_completion_conversations") + @api.doc(description="Get completion conversations with pagination and filtering") + 
@api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("keyword", type=str, location="args", help="Search keyword") + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + .add_argument( + "annotation_status", + type=str, + location="args", + choices=["annotated", "not_annotated", "all"], + default="all", + help="Annotation status filter", + ) + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") + ) + @api.response(200, "Success", conversation_pagination_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -101,7 +123,14 @@ class CompletionConversationApi(Resource): return conversations +@console_ns.route("/apps//completion-conversations/") class CompletionConversationDetailApi(Resource): + @api.doc("get_completion_conversation") + @api.doc(description="Get completion conversation details with messages") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(200, "Success", conversation_message_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -114,6 +143,12 @@ class CompletionConversationDetailApi(Resource): return _get_conversation(app_model, conversation_id) + @api.doc("delete_completion_conversation") + @api.doc(description="Delete a completion conversation") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(204, "Conversation deleted successfully") + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -124,6 +159,8 @@ class CompletionConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -131,7 +168,38 @@ class CompletionConversationDetailApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//chat-conversations") class ChatConversationApi(Resource): + @api.doc("list_chat_conversations") + @api.doc(description="Get chat conversations with pagination, filtering and summary") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("keyword", type=str, location="args", help="Search keyword") + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + .add_argument( + "annotation_status", + type=str, + location="args", + choices=["annotated", "not_annotated", "all"], + default="all", + help="Annotation status filter", + ) + .add_argument("message_count_gte", type=int, location="args", help="Minimum message count") + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") + .add_argument( + "sort_by", + type=str, + location="args", + choices=["created_at", 
"-created_at", "updated_at", "-updated_at"], + default="-updated_at", + help="Sort field and direction", + ) + ) + @api.response(200, "Success", conversation_with_summary_pagination_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -239,7 +307,7 @@ class ChatConversationApi(Resource): .having(func.count(Message.id) >= args["message_count_gte"]) ) - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) match args["sort_by"]: @@ -259,7 +327,14 @@ class ChatConversationApi(Resource): return conversations +@console_ns.route("/apps//chat-conversations/") class ChatConversationDetailApi(Resource): + @api.doc("get_chat_conversation") + @api.doc(description="Get chat conversation details") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(200, "Success", conversation_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -272,6 +347,12 @@ class ChatConversationDetailApi(Resource): return _get_conversation(app_model, conversation_id) + @api.doc("delete_chat_conversation") + @api.doc(description="Delete a chat conversation") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(204, "Conversation deleted successfully") + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @@ -282,6 +363,8 @@ class ChatConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -289,12 +372,6 @@ class ChatConversationDetailApi(Resource): return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, "/apps//completion-conversations") -api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") -api.add_resource(ChatConversationApi, "/apps//chat-conversations") -api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") - - def _get_conversation(app_model, conversation_id): conversation = ( db.session.query(Conversation) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 5ca4c33f87..8a65a89963 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -2,7 +2,7 @@ from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -12,7 +12,17 @@ from models import ConversationVariable from models.model import AppMode +@console_ns.route("/apps//conversation-variables") class ConversationVariablesApi(Resource): + @api.doc("get_conversation_variables") + 
@api.doc(description="Get conversation variables for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "conversation_id", type=str, location="args", help="Conversation ID to filter variables" + ) + ) + @api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields) @setup_required @login_required @account_initialization_required @@ -55,6 +65,3 @@ class ConversationVariablesApi(Resource): for row in rows ], } - - -api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index e1e8bf946a..bcc88c848d 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,9 +1,9 @@ from collections.abc import Sequence from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -22,7 +22,23 @@ from models import App from services.workflow_service import WorkflowService +@console_ns.route("/rule-generate") class RuleGenerateApi(Resource): + @api.doc("generate_rule_config") + @api.doc(description="Generate rule configuration using LLM") + @api.expect( + api.model( + "RuleGenerateRequest", + { + "instruction": fields.String(required=True, description="Rule generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), + }, + ) + ) + @api.response(200, "Rule configuration generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -53,7 +69,26 @@ class RuleGenerateApi(Resource): return rules +@console_ns.route("/rule-code-generate") class RuleCodeGenerateApi(Resource): + @api.doc("generate_rule_code") + @api.doc(description="Generate code rules using LLM") + @api.expect( + api.model( + "RuleCodeGenerateRequest", + { + "instruction": fields.String(required=True, description="Code generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), + "code_language": fields.String( + default="javascript", description="Programming language for code generation" + ), + }, + ) + ) + @api.response(200, "Code rules generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -85,7 +120,22 @@ class RuleCodeGenerateApi(Resource): return code_result +@console_ns.route("/rule-structured-output-generate") class RuleStructuredOutputGenerateApi(Resource): + @api.doc("generate_structured_output") + @api.doc(description="Generate structured output rules using LLM") + @api.expect( + api.model( + "StructuredOutputGenerateRequest", + { + "instruction": fields.String(required=True, description="Structured output generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + }, + ) + ) + @api.response(200, 
"Structured output generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -114,7 +164,27 @@ class RuleStructuredOutputGenerateApi(Resource): return structured_output +@console_ns.route("/instruction-generate") class InstructionGenerateApi(Resource): + @api.doc("generate_instruction") + @api.doc(description="Generate instruction for workflow nodes or general use") + @api.expect( + api.model( + "InstructionGenerateRequest", + { + "flow_id": fields.String(required=True, description="Workflow/Flow ID"), + "node_id": fields.String(description="Node ID for workflow context"), + "current": fields.String(description="Current instruction text"), + "language": fields.String(default="javascript", description="Programming language (javascript/python)"), + "instruction": fields.String(required=True, description="Instruction for generation"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "ideal_output": fields.String(description="Expected ideal output"), + }, + ) + ) + @api.response(200, "Instruction generated successfully") + @api.response(400, "Invalid request parameters or flow/workflow not found") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -203,7 +273,21 @@ class InstructionGenerateApi(Resource): raise CompletionRequestError(e.description) +@console_ns.route("/instruction-generate/template") class InstructionGenerationTemplateApi(Resource): + @api.doc("get_instruction_template") + @api.doc(description="Get instruction generation template") + @api.expect( + api.model( + "InstructionTemplateRequest", + { + "instruction": fields.String(required=True, description="Template instruction"), + "ideal_output": fields.String(description="Expected ideal output"), + }, + ) + ) + @api.response(200, "Template retrieved successfully") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -222,10 +306,3 @@ class InstructionGenerationTemplateApi(Resource): return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} case _: raise ValueError(f"Invalid type: {args['type']}") - - -api.add_resource(RuleGenerateApi, "/rule-generate") -api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") -api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") -api.add_resource(InstructionGenerateApi, "/instruction-generate") -api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template") diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 541803e539..b9a383ee61 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,10 +2,10 @@ import json from enum import StrEnum from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -19,7 +19,12 @@ class AppMCPServerStatus(StrEnum): INACTIVE = "inactive" +@console_ns.route("/apps//server") class AppMCPServerController(Resource): + 
@api.doc("get_app_mcp_server") + @api.doc(description="Get MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @setup_required @login_required @account_initialization_required @@ -29,6 +34,20 @@ class AppMCPServerController(Resource): server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() return server + @api.doc("create_app_mcp_server") + @api.doc(description="Create MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MCPServerCreateRequest", + { + "description": fields.String(description="Server description"), + "parameters": fields.Raw(required=True, description="Server parameters configuration"), + }, + ) + ) + @api.response(201, "MCP server configuration created successfully", app_server_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -59,6 +78,23 @@ class AppMCPServerController(Resource): db.session.commit() return server + @api.doc("update_app_mcp_server") + @api.doc(description="Update MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MCPServerUpdateRequest", + { + "id": fields.String(required=True, description="Server ID"), + "description": fields.String(description="Server description"), + "parameters": fields.Raw(required=True, description="Server parameters configuration"), + "status": fields.String(description="Server status"), + }, + ) + ) + @api.response(200, "MCP server configuration updated successfully", app_server_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Server not found") @setup_required @login_required @account_initialization_required @@ -94,7 +130,14 @@ class AppMCPServerController(Resource): return server +@console_ns.route("/apps//server/refresh") class AppMCPServerRefreshController(Resource): + @api.doc("refresh_app_mcp_server") + @api.doc(description="Refresh MCP server configuration and regenerate server code") + @api.doc(params={"server_id": "Server ID"}) + @api.response(200, "MCP server refreshed successfully", app_server_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Server not found") @setup_required @login_required @account_initialization_required @@ -113,7 +156,3 @@ class AppMCPServerRefreshController(Resource): server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() return server - - -api.add_resource(AppMCPServerController, "/apps//server") -api.add_resource(AppMCPServerRefreshController, "/apps//server/refresh") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index f0605a37f9..3bd9c53a85 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy import exists, select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -27,7 +26,8 @@ from extensions.ext_database import db from 
fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError @@ -37,6 +37,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@console_ns.route("/apps//chat-messages") class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -44,6 +45,17 @@ class ChatMessageListApi(Resource): "data": fields.List(fields.Nested(message_detail_fields)), } + @api.doc("list_chat_messages") + @api.doc(description="Get chat messages for a conversation with pagination") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID") + .add_argument("first_id", type=str, location="args", help="First message ID for pagination") + .add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)") + ) + @api.response(200, "Success", message_infinite_scroll_pagination_fields) + @api.response(404, "Conversation not found") @setup_required @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @@ -117,12 +129,31 @@ class ChatMessageListApi(Resource): return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) +@console_ns.route("/apps//feedbacks") class MessageFeedbackApi(Resource): + @api.doc("create_message_feedback") + @api.doc(description="Create or update message feedback (like/dislike)") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MessageFeedbackRequest", + { + "message_id": fields.String(required=True, description="Message ID"), + "rating": fields.String(enum=["like", "dislike"], description="Feedback rating"), + }, + ) + ) + @api.response(200, "Feedback updated successfully") + @api.response(404, "Message not found") + @api.response(403, "Insufficient permissions") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model): + if current_user is None: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") @@ -159,7 +190,24 @@ class MessageFeedbackApi(Resource): return {"result": "success"} +@console_ns.route("/apps//annotations") class MessageAnnotationApi(Resource): + @api.doc("create_message_annotation") + @api.doc(description="Create message annotation") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MessageAnnotationRequest", + { + "message_id": fields.String(description="Message ID"), + "question": fields.String(required=True, description="Question text"), + "answer": fields.String(required=True, description="Answer text"), + "annotation_reply": fields.Raw(description="Annotation reply"), + }, + ) + ) + @api.response(200, "Annotation created successfully", annotation_fields) + @api.response(403, "Insufficient permissions") 
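+    # Typing note (inferred): the isinstance(current_user, Account) guards in this
+    # file double as type narrowing for static checkers, since libs.login exposes
+    # current_user as a proxy whose declared type is broader than Account.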
@setup_required @login_required @account_initialization_required @@ -167,7 +215,9 @@ class MessageAnnotationApi(Resource): @get_app_model @marshal_with(annotation_fields) def post(self, app_model): - if not current_user.is_editor: + if not isinstance(current_user, Account): + raise Forbidden() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -181,18 +231,37 @@ class MessageAnnotationApi(Resource): return annotation +@console_ns.route("/apps//annotations/count") class MessageAnnotationCountApi(Resource): + @api.doc("get_annotation_count") + @api.doc(description="Get count of message annotations for the app") + @api.doc(params={"app_id": "Application ID"}) + @api.response( + 200, + "Annotation count retrieved successfully", + api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() return {"count": count} +@console_ns.route("/apps//chat-messages//suggested-questions") class MessageSuggestedQuestionApi(Resource): + @api.doc("get_message_suggested_questions") + @api.doc(description="Get suggested questions for a message") + @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) + @api.response( + 200, + "Suggested questions retrieved successfully", + api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}), + ) + @api.response(404, "Message or conversation not found") @setup_required @login_required @account_initialization_required @@ -225,7 +294,13 @@ class MessageSuggestedQuestionApi(Resource): return {"data": questions} +@console_ns.route("/apps//messages/") class MessageApi(Resource): + @api.doc("get_message") + @api.doc(description="Get message details by ID") + @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) + @api.response(200, "Message retrieved successfully", message_detail_fields) + @api.response(404, "Message not found") @setup_required @login_required @account_initialization_required @@ -240,11 +315,3 @@ class MessageApi(Resource): raise NotFound("Message Not Exists.") return message - - -api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") -api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") -api.add_resource(MessageFeedbackApi, "/apps//feedbacks") -api.add_resource(MessageAnnotationApi, "/apps//annotations") -api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") -api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 52ff9b923d..11df511840 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -3,9 +3,10 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields +from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.agent.entities import AgentToolEntity @@ 
-14,17 +15,51 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required +from models.account import Account from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService +@console_ns.route("/apps//model-config") class ModelConfigResource(Resource): + @api.doc("update_app_model_config") + @api.doc(description="Update application model configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "ModelConfigRequest", + { + "provider": fields.String(description="Model provider"), + "model": fields.String(description="Model name"), + "configs": fields.Raw(description="Model configuration parameters"), + "opening_statement": fields.String(description="Opening statement"), + "suggested_questions": fields.List(fields.String(), description="Suggested questions"), + "more_like_this": fields.Raw(description="More like this configuration"), + "speech_to_text": fields.Raw(description="Speech to text configuration"), + "text_to_speech": fields.Raw(description="Text to speech configuration"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "tools": fields.List(fields.Raw(), description="Available tools"), + "dataset_configs": fields.Raw(description="Dataset configurations"), + "agent_mode": fields.Raw(description="Agent mode configuration"), + }, + ) + ) + @api.response(200, "Model configuration updated successfully") + @api.response(400, "Invalid configuration") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" + if not isinstance(current_user, Account): + raise Forbidden() + + if not current_user.has_edit_permission: + raise Forbidden() + + assert current_user.current_tenant_id is not None, "The tenant information should be loaded." 
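+        # The guard trio above recurs across these controllers; a minimal
+        # consolidation sketch (hypothetical helper, not part of this patch):
+        #
+        #     def ensure_editor(user: object) -> Account:
+        #         if not isinstance(user, Account) or not user.has_edit_permission:
+        #             raise Forbidden()
+        #         if user.current_tenant_id is None:
+        #             raise ValueError("tenant must be loaded")
+        #         return user
+        #
+        # which would let each view write: account = ensure_editor(current_user)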
# validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, @@ -39,7 +74,7 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config original_app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() @@ -142,6 +177,3 @@ class ModelConfigResource(Resource): app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) return {"result": "success"} - - -api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 74c2867c2f..981974e842 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,18 +1,31 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import BadRequest -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService +@console_ns.route("/apps//trace-config") class TraceAppConfigApi(Resource): """ Manage trace app configurations """ + @api.doc("get_trace_app_config") + @api.doc(description="Get tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" + ) + ) + @api.response( + 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") + ) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -29,6 +42,22 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("create_trace_app_config") + @api.doc(description="Create a new tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "TraceConfigCreateRequest", + { + "tracing_provider": fields.String(required=True, description="Tracing provider name"), + "tracing_config": fields.Raw(required=True, description="Tracing configuration data"), + }, + ) + ) + @api.response( + 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") + ) + @api.response(400, "Invalid request parameters or configuration already exists") @setup_required @login_required @account_initialization_required @@ -51,6 +80,20 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("update_trace_app_config") + @api.doc(description="Update an existing tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "TraceConfigUpdateRequest", + { + "tracing_provider": fields.String(required=True, description="Tracing provider name"), + "tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"), + }, + ) + ) + 
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) + @api.response(400, "Invalid request parameters or configuration not found") @setup_required @login_required @account_initialization_required @@ -71,6 +114,16 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("delete_trace_app_config") + @api.doc(description="Delete an existing tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" + ) + ) + @api.response(204, "Tracing configuration deleted successfully") + @api.response(400, "Invalid request parameters or configuration not found") @setup_required @login_required @account_initialization_required @@ -87,6 +140,3 @@ class TraceAppConfigApi(Resource): return {"result": "success"}, 204 except Exception as e: raise BadRequest(str(e)) - - -api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 778ce92da6..95befc5df9 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,16 +1,16 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now from libs.login import login_required -from models import Site +from models import Account, Site def parse_app_site_args(): @@ -36,7 +36,39 @@ def parse_app_site_args(): return parser.parse_args() +@console_ns.route("/apps//site") class AppSite(Resource): + @api.doc("update_app_site") + @api.doc(description="Update application site configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppSiteRequest", + { + "title": fields.String(description="Site title"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + "description": fields.String(description="Site description"), + "default_language": fields.String(description="Default language"), + "chat_color_theme": fields.String(description="Chat color theme"), + "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"), + "customize_domain": fields.String(description="Custom domain"), + "copyright": fields.String(description="Copyright text"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "customize_token_strategy": fields.String( + enum=["must", "allow", "not_allow"], description="Token strategy" + ), + "prompt_public": fields.Boolean(description="Make prompt public"), + "show_workflow_steps": fields.Boolean(description="Show workflow steps"), + "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), + }, + ) + ) + @api.response(200, "Site 
configuration updated successfully", app_site_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @@ -75,6 +107,8 @@ class AppSite(Resource): if value is not None: setattr(site, attr_name, value) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() @@ -82,7 +116,14 @@ class AppSite(Resource): return site +@console_ns.route("/apps//site/access-token-reset") class AppSiteAccessTokenReset(Resource): + @api.doc("reset_app_site_access_token") + @api.doc(description="Reset access token for application site") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Access token reset successfully", app_site_fields) + @api.response(403, "Insufficient permissions (admin/owner required)") + @api.response(404, "App or site not found") @setup_required @login_required @account_initialization_required @@ -99,12 +140,10 @@ class AppSiteAccessTokenReset(Resource): raise NotFound site.code = Site.generate_code(16) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() return site - - -api.add_resource(AppSite, "/apps//site") -api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 27e405af38..6894458578 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,9 +5,9 @@ import pytz import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,11 +17,25 @@ from libs.login import login_required from models import AppMode, Message +@console_ns.route("/apps//statistics/daily-messages") class DailyMessageStatistic(Resource): + @api.doc("get_daily_message_statistics") + @api.doc(description="Get daily message statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily message statistics retrieved successfully", + fields.List(fields.Raw(description="Daily message count data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -74,11 +88,25 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/daily-conversations") class DailyConversationStatistic(Resource): + @api.doc("get_daily_conversation_statistics") + @api.doc(description="Get daily conversation statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") 
+ .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily conversation statistics retrieved successfully", + fields.List(fields.Raw(description="Daily conversation count data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -126,11 +154,25 @@ class DailyConversationStatistic(Resource): return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/daily-end-users") class DailyTerminalsStatistic(Resource): + @api.doc("get_daily_terminals_statistics") + @api.doc(description="Get daily terminal/end-user statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily terminal statistics retrieved successfully", + fields.List(fields.Raw(description="Daily terminal count data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -183,11 +225,25 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/token-costs") class DailyTokenCostStatistic(Resource): + @api.doc("get_daily_token_cost_statistics") + @api.doc(description="Get daily token cost statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily token cost statistics retrieved successfully", + fields.List(fields.Raw(description="Daily token cost data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -243,7 +299,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/average-session-interactions") class AverageSessionInteractionStatistic(Resource): + @api.doc("get_average_session_interaction_statistics") + @api.doc(description="Get average session interaction statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Average session interaction statistics retrieved successfully", + fields.List(fields.Raw(description="Average session interaction data")), + ) @setup_required @login_required @account_initialization_required @@ -319,11 +389,25 @@ ORDER BY return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/user-satisfaction-rate") class UserSatisfactionRateStatistic(Resource): + @api.doc("get_user_satisfaction_rate_statistics") + @api.doc(description="Get user satisfaction rate statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "User 
satisfaction rate statistics retrieved successfully", + fields.List(fields.Raw(description="User satisfaction rate data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -385,7 +469,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/average-response-time") class AverageResponseTimeStatistic(Resource): + @api.doc("get_average_response_time_statistics") + @api.doc(description="Get average response time statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Average response time statistics retrieved successfully", + fields.List(fields.Raw(description="Average response time data")), + ) @setup_required @login_required @account_initialization_required @@ -442,11 +540,25 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/tokens-per-second") class TokensPerSecondStatistic(Resource): + @api.doc("get_tokens_per_second_statistics") + @api.doc(description="Get tokens per second statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Tokens per second statistics retrieved successfully", + fields.List(fields.Raw(description="Tokens per second data")), + ) + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -500,13 +612,3 @@ WHERE response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) return jsonify({"data": response_data}) - - -api.add_resource(DailyMessageStatistic, "/apps//statistics/daily-messages") -api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") -api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") -api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") -api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") -api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") -api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") -api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1a472e771d..a8fadb4a4f 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -4,13 +4,13 @@ from collections.abc import Sequence from typing import cast from flask import abort, request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ConversationCompletedError, 
DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -58,7 +58,13 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence return file_objs +@console_ns.route("/apps//workflows/draft") class DraftWorkflowApi(Resource): + @api.doc("get_draft_workflow") + @api.doc(description="Get draft workflow for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Draft workflow retrieved successfully", workflow_fields) + @api.response(404, "Draft workflow not found") @setup_required @login_required @account_initialization_required @@ -70,7 +76,7 @@ class DraftWorkflowApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # fetch draft workflow by app_model @@ -87,13 +93,30 @@ class DraftWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @api.doc("sync_draft_workflow") + @api.doc(description="Sync draft workflow configuration") + @api.expect( + api.model( + "SyncDraftWorkflowRequest", + { + "graph": fields.Raw(required=True, description="Workflow graph configuration"), + "features": fields.Raw(required=True, description="Workflow features configuration"), + "hash": fields.String(description="Workflow hash for validation"), + "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"), + "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), + }, + ) + ) + @api.response(200, "Draft workflow synced successfully", workflow_fields) + @api.response(400, "Invalid workflow configuration") + @api.response(403, "Permission denied") def post(self, app_model: App): """ Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() content_type = request.headers.get("Content-Type", "") @@ -160,7 +183,25 @@ class DraftWorkflowApi(Resource): } +@console_ns.route("/apps//advanced-chat/workflows/draft/run") class AdvancedChatDraftWorkflowRunApi(Resource): + @api.doc("run_advanced_chat_draft_workflow") + @api.doc(description="Run draft workflow for advanced chat application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AdvancedChatWorkflowRunRequest", + { + "query": fields.String(required=True, description="User query"), + "inputs": fields.Raw(description="Input variables"), + "files": fields.List(fields.Raw, description="File uploads"), + "conversation_id": fields.String(description="Conversation ID"), + }, + ) + ) + @api.response(200, "Workflow run started successfully") + @api.response(400, "Invalid request parameters") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -171,7 +212,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() if not isinstance(current_user, Account): @@ -209,7 +250,23 @@ class 
AdvancedChatDraftWorkflowRunApi(Resource): raise InternalServerError() +@console_ns.route("/apps//advanced-chat/workflows/draft/iteration/nodes//run") class AdvancedChatDraftRunIterationNodeApi(Resource): + @api.doc("run_advanced_chat_draft_iteration_node") + @api.doc(description="Run draft workflow iteration node for advanced chat") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "IterationNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Iteration node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -221,7 +278,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -245,7 +302,23 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//workflows/draft/iteration/nodes//run") class WorkflowDraftRunIterationNodeApi(Resource): + @api.doc("run_workflow_draft_iteration_node") + @api.doc(description="Run draft workflow iteration node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "WorkflowIterationNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Workflow iteration node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -257,7 +330,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -281,7 +354,23 @@ class WorkflowDraftRunIterationNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//advanced-chat/workflows/draft/loop/nodes//run") class AdvancedChatDraftRunLoopNodeApi(Resource): + @api.doc("run_advanced_chat_draft_loop_node") + @api.doc(description="Run draft workflow loop node for advanced chat") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "LoopNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Loop node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -294,7 +383,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -318,7 +407,23 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): raise InternalServerError() 
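
The hunks above all follow the same migration pattern: each Resource class is registered by decorating it with @console_ns.route(...) and documented inline with @api.doc / @api.expect / @api.response, replacing the api.add_resource(...) calls that used to sit at the bottom of each module. A minimal, self-contained flask-restx sketch of that pattern follows; the namespace, class name, and /demo path are illustrative stand-ins, not part of this patch.

# --- illustrative sketch, not part of the patch ---
from flask import Flask
from flask_restx import Api, Namespace, Resource, fields

app = Flask(__name__)
api = Api(app)
demo_ns = Namespace("demo", path="/demo/api")  # hypothetical namespace
api.add_namespace(demo_ns)

# Registering the route on the class replaces a separate api.add_resource() call.
@demo_ns.route("/apps/<uuid:app_id>/nodes/<string:node_id>/run")
class DemoNodeRunApi(Resource):
    @demo_ns.doc("run_demo_node")
    @demo_ns.doc(description="Run a single node (illustrative endpoint)")
    @demo_ns.expect(demo_ns.model("DemoNodeRunRequest", {"inputs": fields.Raw(description="Input variables")}))
    @demo_ns.response(200, "Node run started successfully")
    def post(self, app_id, node_id):
        # Flask parses the <uuid:...> and <string:...> converters and passes
        # the values as keyword arguments to the view method.
        return {"app_id": str(app_id), "node_id": node_id}
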
+@console_ns.route("/apps//workflows/draft/loop/nodes//run") class WorkflowDraftRunLoopNodeApi(Resource): + @api.doc("run_workflow_draft_loop_node") + @api.doc(description="Run draft workflow loop node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "WorkflowLoopNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Workflow loop node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -331,7 +436,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -355,7 +460,22 @@ class WorkflowDraftRunLoopNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//workflows/draft/run") class DraftWorkflowRunApi(Resource): + @api.doc("run_draft_workflow") + @api.doc(description="Run draft workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "DraftWorkflowRunRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "files": fields.List(fields.Raw, description="File uploads"), + }, + ) + ) + @api.response(200, "Draft workflow run started successfully") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -368,7 +488,7 @@ class DraftWorkflowRunApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -394,7 +514,14 @@ class DraftWorkflowRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) +@console_ns.route("/apps//workflows/tasks//stop") class WorkflowTaskStopApi(Resource): + @api.doc("stop_workflow_task") + @api.doc(description="Stop running workflow task") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID"}) + @api.response(200, "Task stopped successfully") + @api.response(404, "Task not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -407,7 +534,7 @@ class WorkflowTaskStopApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # Stop using both mechanisms for backward compatibility @@ -420,7 +547,22 @@ class WorkflowTaskStopApi(Resource): return {"result": "success"} +@console_ns.route("/apps//workflows/draft/nodes//run") class DraftWorkflowNodeRunApi(Resource): + @api.doc("run_draft_workflow_node") + @api.doc(description="Run draft workflow node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "DraftWorkflowNodeRunRequest", + { + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Node run started successfully", workflow_run_node_execution_fields) + @api.response(403, "Permission denied") + @api.response(404, "Node 
not found") @setup_required @login_required @account_initialization_required @@ -434,7 +576,7 @@ class DraftWorkflowNodeRunApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -468,7 +610,13 @@ class DraftWorkflowNodeRunApi(Resource): return workflow_node_execution +@console_ns.route("/apps//workflows/publish") class PublishedWorkflowApi(Resource): + @api.doc("get_published_workflow") + @api.doc(description="Get published workflow for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Published workflow retrieved successfully", workflow_fields) + @api.response(404, "Published workflow not found") @setup_required @login_required @account_initialization_required @@ -482,7 +630,7 @@ class PublishedWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # fetch published workflow by app_model @@ -503,7 +651,7 @@ class PublishedWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -540,7 +688,12 @@ class PublishedWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/default-block-configs") class DefaultBlockConfigsApi(Resource): + @api.doc("get_default_block_configs") + @api.doc(description="Get default block configurations for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Default block configurations retrieved successfully") @setup_required @login_required @account_initialization_required @@ -553,7 +706,7 @@ class DefaultBlockConfigsApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # Get default block configs @@ -561,7 +714,13 @@ class DefaultBlockConfigsApi(Resource): return workflow_service.get_default_block_configs() +@console_ns.route("/apps//workflows/default-block-configs/") class DefaultBlockConfigApi(Resource): + @api.doc("get_default_block_config") + @api.doc(description="Get default block configuration by type") + @api.doc(params={"app_id": "Application ID", "block_type": "Block type"}) + @api.response(200, "Default block configuration retrieved successfully") + @api.response(404, "Block type not found") @setup_required @login_required @account_initialization_required @@ -573,7 +732,7 @@ class DefaultBlockConfigApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -594,7 +753,14 @@ class DefaultBlockConfigApi(Resource): return workflow_service.get_default_block_config(node_type=block_type, filters=filters) +@console_ns.route("/apps//convert-to-workflow") class ConvertToWorkflowApi(Resource): + @api.doc("convert_to_workflow") + 
@api.doc(description="Convert application to workflow mode") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Application converted to workflow successfully") + @api.response(400, "Application cannot be converted") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -608,7 +774,7 @@ class ConvertToWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() if request.data: @@ -631,9 +797,14 @@ class ConvertToWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/config") class WorkflowConfigApi(Resource): """Resource for workflow configuration.""" + @api.doc("get_workflow_config") + @api.doc(description="Get workflow configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Workflow configuration retrieved successfully") @setup_required @login_required @account_initialization_required @@ -644,7 +815,12 @@ class WorkflowConfigApi(Resource): } +@console_ns.route("/apps//workflows/published") class PublishedAllWorkflowApi(Resource): + @api.doc("get_all_published_workflows") + @api.doc(description="Get all published workflows for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Published workflows retrieved successfully", workflow_pagination_fields) @setup_required @login_required @account_initialization_required @@ -657,7 +833,7 @@ class PublishedAllWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -695,7 +871,23 @@ class PublishedAllWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): + @api.doc("update_workflow_by_id") + @api.doc(description="Update workflow by ID") + @api.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"}) + @api.expect( + api.model( + "UpdateWorkflowRequest", + { + "environment_variables": fields.List(fields.Raw, description="Environment variables"), + "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), + }, + ) + ) + @api.response(200, "Workflow updated successfully", workflow_fields) + @api.response(404, "Workflow not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -708,7 +900,7 @@ class WorkflowByIdApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # Check permission - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -721,7 +913,6 @@ class WorkflowByIdApi(Resource): raise ValueError("Marked name cannot exceed 20 characters") if args.marked_comment and len(args.marked_comment) > 100: raise ValueError("Marked comment cannot exceed 100 characters") - args = parser.parse_args() # Prepare update data update_data = {} @@ -764,7 +955,7 @@ class WorkflowByIdApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # Check permission - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() workflow_service = WorkflowService() @@ -787,7 +978,14 @@ class WorkflowByIdApi(Resource): return None, 204 
+@console_ns.route("/apps//workflows/draft/nodes//last-run") class DraftWorkflowNodeLastRunApi(Resource): + @api.doc("get_draft_workflow_node_last_run") + @api.doc(description="Get last run result for draft workflow node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields) + @api.response(404, "Node last run not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -806,73 +1004,3 @@ class DraftWorkflowNodeLastRunApi(Resource): if node_exec is None: raise NotFound("last run not found") return node_exec - - -api.add_resource( - DraftWorkflowApi, - "/apps//workflows/draft", -) -api.add_resource( - WorkflowConfigApi, - "/apps//workflows/draft/config", -) -api.add_resource( - AdvancedChatDraftWorkflowRunApi, - "/apps//advanced-chat/workflows/draft/run", -) -api.add_resource( - DraftWorkflowRunApi, - "/apps//workflows/draft/run", -) -api.add_resource( - WorkflowTaskStopApi, - "/apps//workflow-runs/tasks//stop", -) -api.add_resource( - DraftWorkflowNodeRunApi, - "/apps//workflows/draft/nodes//run", -) -api.add_resource( - AdvancedChatDraftRunIterationNodeApi, - "/apps//advanced-chat/workflows/draft/iteration/nodes//run", -) -api.add_resource( - WorkflowDraftRunIterationNodeApi, - "/apps//workflows/draft/iteration/nodes//run", -) -api.add_resource( - AdvancedChatDraftRunLoopNodeApi, - "/apps//advanced-chat/workflows/draft/loop/nodes//run", -) -api.add_resource( - WorkflowDraftRunLoopNodeApi, - "/apps//workflows/draft/loop/nodes//run", -) -api.add_resource( - PublishedWorkflowApi, - "/apps//workflows/publish", -) -api.add_resource( - PublishedAllWorkflowApi, - "/apps//workflows", -) -api.add_resource( - DefaultBlockConfigsApi, - "/apps//workflows/default-workflow-block-configs", -) -api.add_resource( - DefaultBlockConfigApi, - "/apps//workflows/default-workflow-block-configs/", -) -api.add_resource( - ConvertToWorkflowApi, - "/apps//convert-to-workflow", -) -api.add_resource( - WorkflowByIdApi, - "/apps//workflows/", -) -api.add_resource( - DraftWorkflowNodeLastRunApi, - "/apps//workflows/draft/nodes//last-run", -) diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 8f4dcbfd42..8e24be4fa7 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,7 +3,7 @@ from flask_restx import Resource, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.workflow.enums import WorkflowExecutionStatus @@ -15,7 +15,24 @@ from models.model import AppMode from services.workflow_app_service import WorkflowAppService +@console_ns.route("/apps//workflow-app-logs") class WorkflowAppLogApi(Resource): + @api.doc("get_workflow_app_logs") + @api.doc(description="Get workflow application execution logs") + @api.doc(params={"app_id": "Application ID"}) + @api.doc( + params={ + "keyword": "Search keyword for filtering logs", + "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)", + "created_at__before": "Filter logs created before this timestamp", + "created_at__after": "Filter logs created after this timestamp", 
+ "created_by_end_user_session_id": "Filter by end user session ID", + "created_by_account": "Filter by account", + "page": "Page number (1-99999)", + "limit": "Number of items per page (1-100)", + } + ) + @api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields) @setup_required @login_required @account_initialization_required @@ -78,6 +95,3 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination - - -api.add_resource(WorkflowAppLogApi, "/apps//workflow-app-logs") diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 5467e3cd5e..da6b56d026 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -6,7 +6,7 @@ from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqpars from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, ) @@ -157,14 +157,20 @@ def _api_prerequisite(f): @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def wrapper(*args, **kwargs): assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) return wrapper +@console_ns.route("/apps//workflows/draft/variables") class WorkflowVariableCollectionApi(Resource): + @api.doc("get_workflow_variables") + @api.doc(description="Get draft workflow variables") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"}) + @api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) def get(self, app_model: App): @@ -193,6 +199,9 @@ class WorkflowVariableCollectionApi(Resource): return workflow_vars + @api.doc("delete_workflow_variables") + @api.doc(description="Delete all draft workflow variables") + @api.response(204, "Workflow variables deleted successfully") @_api_prerequisite def delete(self, app_model: App): draft_var_srv = WorkflowDraftVariableService( @@ -221,7 +230,12 @@ def validate_node_id(node_id: str) -> NoReturn | None: return None +@console_ns.route("/apps//workflows/draft/nodes//variables") class NodeVariableCollectionApi(Resource): + @api.doc("get_node_variables") + @api.doc(description="Get variables for a specific node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App, node_id: str): @@ -234,6 +248,9 @@ class NodeVariableCollectionApi(Resource): return node_vars + @api.doc("delete_node_variables") + @api.doc(description="Delete all variables for a specific node") + @api.response(204, "Node variables deleted successfully") @_api_prerequisite def delete(self, app_model: App, node_id: str): validate_node_id(node_id) @@ -243,10 +260,16 @@ class NodeVariableCollectionApi(Resource): return Response("", 204) +@console_ns.route("/apps//workflows/draft/variables/") class VariableApi(Resource): _PATCH_NAME_FIELD = "name" _PATCH_VALUE_FIELD = "value" + @api.doc("get_variable") + 
@api.doc(description="Get a specific workflow variable") + @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) + @api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) + @api.response(404, "Variable not found") @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) def get(self, app_model: App, variable_id: str): @@ -260,6 +283,19 @@ class VariableApi(Resource): raise NotFoundError(description=f"variable not found, id={variable_id}") return variable + @api.doc("update_variable") + @api.doc(description="Update a workflow variable") + @api.expect( + api.model( + "UpdateVariableRequest", + { + "name": fields.String(description="Variable name"), + "value": fields.Raw(description="Variable value"), + }, + ) + ) + @api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) + @api.response(404, "Variable not found") @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) def patch(self, app_model: App, variable_id: str): @@ -322,6 +358,10 @@ class VariableApi(Resource): db.session.commit() return variable + @api.doc("delete_variable") + @api.doc(description="Delete a workflow variable") + @api.response(204, "Variable deleted successfully") + @api.response(404, "Variable not found") @_api_prerequisite def delete(self, app_model: App, variable_id: str): draft_var_srv = WorkflowDraftVariableService( @@ -337,7 +377,14 @@ class VariableApi(Resource): return Response("", 204) +@console_ns.route("/apps//workflows/draft/variables//reset") class VariableResetApi(Resource): + @api.doc("reset_variable") + @api.doc(description="Reset a workflow variable to its default value") + @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) + @api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) + @api.response(204, "Variable reset (no content)") + @api.response(404, "Variable not found") @_api_prerequisite def put(self, app_model: App, variable_id: str): draft_var_srv = WorkflowDraftVariableService( @@ -378,7 +425,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: return draft_vars +@console_ns.route("/apps//workflows/draft/conversation-variables") class ConversationVariableCollectionApi(Resource): + @api.doc("get_conversation_variables") + @api.doc(description="Get conversation variables for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + @api.response(404, "Draft workflow not found") @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App): @@ -394,14 +447,25 @@ class ConversationVariableCollectionApi(Resource): return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) +@console_ns.route("/apps//workflows/draft/system-variables") class SystemVariableCollectionApi(Resource): + @api.doc("get_system_variables") + @api.doc(description="Get system variables for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App): return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) +@console_ns.route("/apps//workflows/draft/environment-variables") class EnvironmentVariableCollectionApi(Resource): + @api.doc("get_environment_variables") + @api.doc(description="Get 
environment variables for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Environment variables retrieved successfully") + @api.response(404, "Draft workflow not found") @_api_prerequisite def get(self, app_model: App): """ @@ -433,16 +497,3 @@ class EnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} - - -api.add_resource( - WorkflowVariableCollectionApi, - "/apps//workflows/draft/variables", -) -api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") -api.add_resource(VariableApi, "/apps//workflows/draft/variables/") -api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") - -api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") -api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") -api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index dccbfd8648..23ba63845c 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -4,7 +4,7 @@ from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from flask_restx.inputs import int_range -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( @@ -19,7 +19,13 @@ from models import Account, App, AppMode, EndUser from services.workflow_run_service import WorkflowRunService +@console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): + @api.doc("get_advanced_chat_workflow_runs") + @api.doc(description="Get advanced chat workflow run list") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) @setup_required @login_required @account_initialization_required @@ -40,7 +46,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs") class WorkflowRunListApi(Resource): + @api.doc("get_workflow_runs") + @api.doc(description="Get workflow run list") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @setup_required @login_required @account_initialization_required @@ -61,7 +73,13 @@ class WorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs/") class WorkflowRunDetailApi(Resource): + @api.doc("get_workflow_run_detail") + @api.doc(description="Get workflow run detail") + @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields) + @api.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @@ -79,7 +97,13 @@ class WorkflowRunDetailApi(Resource): return workflow_run +@console_ns.route("/apps//workflow-runs//node-executions") class 
WorkflowRunNodeExecutionListApi(Resource): + @api.doc("get_workflow_run_node_executions") + @api.doc(description="Get workflow run node execution list") + @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields) + @api.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @@ -100,9 +124,3 @@ class WorkflowRunNodeExecutionListApi(Resource): ) return {"data": node_executions} - - -api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps//advanced-chat/workflow-runs") -api.add_resource(WorkflowRunListApi, "/apps//workflow-runs") -api.add_resource(WorkflowRunDetailApi, "/apps//workflow-runs/") -api.add_resource(WorkflowRunNodeExecutionListApi, "/apps//workflow-runs//node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 7cef175c14..535e7cadd6 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -7,7 +7,7 @@ from flask import jsonify from flask_login import current_user from flask_restx import Resource, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -17,11 +17,17 @@ from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode +@console_ns.route("/apps//workflow/statistics/daily-conversations") class WorkflowDailyRunsStatistic(Resource): + @api.doc("get_workflow_daily_runs_statistic") + @api.doc(description="Get workflow daily runs statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Daily runs statistics retrieved successfully") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -79,11 +85,17 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//workflow/statistics/daily-terminals") class WorkflowDailyTerminalsStatistic(Resource): + @api.doc("get_workflow_daily_terminals_statistic") + @api.doc(description="Get workflow daily terminals statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Daily terminals statistics retrieved successfully") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -141,11 +153,17 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//workflow/statistics/token-costs") class WorkflowDailyTokenCostStatistic(Resource): + @api.doc("get_workflow_daily_token_cost_statistic") + @api.doc(description="Get workflow daily token cost statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Daily token cost statistics retrieved successfully") + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def 
get(self, app_model): account = current_user @@ -208,7 +226,13 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//workflow/statistics/average-app-interactions") class WorkflowAverageAppInteractionStatistic(Resource): + @api.doc("get_workflow_average_app_interaction_statistic") + @api.doc(description="Get workflow average app interaction statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Average app interaction statistics retrieved successfully") @setup_required @login_required @account_initialization_required @@ -285,11 +309,3 @@ GROUP BY ) return jsonify({"data": response_data}) - - -api.add_resource(WorkflowDailyRunsStatistic, "/apps//workflow/statistics/daily-conversations") -api.add_resource(WorkflowDailyTerminalsStatistic, "/apps//workflow/statistics/daily-terminals") -api.add_resource(WorkflowDailyTokenCostStatistic, "/apps//workflow/statistics/token-costs") -api.add_resource( - WorkflowAverageAppInteractionStatistic, "/apps//workflow/statistics/average-app-interactions" -) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c7e300279a..44aba01820 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, Union +from typing import ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -8,8 +8,11 @@ from libs.login import current_user from models import App, AppMode from models.account import Account +P = ParamSpec("P") +R = TypeVar("R") -def _load_app_model(app_id: str) -> Optional[App]: + +def _load_app_model(app_id: str) -> App | None: assert isinstance(current_user, Account) app_model = ( db.session.query(App) @@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): - def decorator(view_func): +def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index e82e403ec2..8cdadfb03c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,8 +1,8 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -10,14 +10,36 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone from models.account import AccountStatus from services.account_service import AccountService, RegisterService +active_check_parser = reqparse.RequestParser() +active_check_parser.add_argument( + "workspace_id", type=str, required=False, nullable=True, location="args", 
help="Workspace ID" +) +active_check_parser.add_argument( + "email", type=email, required=False, nullable=True, location="args", help="Email address" +) +active_check_parser.add_argument( + "token", type=str, required=True, nullable=False, location="args", help="Activation token" +) + +@console_ns.route("/activate/check") class ActivateCheckApi(Resource): + @api.doc("check_activation_token") + @api.doc(description="Check if activation token is valid") + @api.expect(active_check_parser) + @api.response( + 200, + "Success", + api.model( + "ActivationCheckResponse", + { + "is_valid": fields.Boolean(description="Whether token is valid"), + "data": fields.Raw(description="Activation data if valid"), + }, + ), + ) def get(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") - parser.add_argument("email", type=email, required=False, nullable=True, location="args") - parser.add_argument("token", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() + args = active_check_parser.parse_args() workspaceId = args["workspace_id"] reg_email = args["email"] @@ -38,18 +60,36 @@ class ActivateCheckApi(Resource): return {"is_valid": False} +active_parser = reqparse.RequestParser() +active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") +active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") +active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") +active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") +active_parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" +) +active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") + + +@console_ns.route("/activate") class ActivateApi(Resource): + @api.doc("activate_account") + @api.doc(description="Activate account with invitation token") + @api.expect(active_parser) + @api.response( + 200, + "Account activated successfully", + api.model( + "ActivationResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.Raw(description="Login token data"), + }, + ), + ) + @api.response(400, "Already activated or invalid token") def post(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") - parser.add_argument("email", type=email, required=False, nullable=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - parser.add_argument( - "interface_language", type=supported_language, required=True, nullable=False, location="json" - ) - parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") - args = parser.parse_args() + args = active_parser.parse_args() invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: @@ -70,7 +110,3 @@ class ActivateApi(Resource): token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success", "data": token_pair.model_dump()} - - -api.add_resource(ActivateCheckApi, "/activate/check") -api.add_resource(ActivateApi, "/activate") diff --git 
a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 8f57b3d03e..fc4ba3a2c7 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,11 +3,11 @@ import logging import requests from flask import current_app, redirect, request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -28,7 +28,21 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/data-source/<string:provider>") class OAuthDataSource(Resource): + @api.doc("oauth_data_source") + @api.doc(description="Get OAuth authorization URL for data source provider") + @api.doc(params={"provider": "Data source provider name (notion)"}) + @api.response( + 200, + "Authorization URL or internal setup success", + api.model( + "OAuthDataSourceResponse", + {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, + ), + ) + @api.response(400, "Invalid provider") + @api.response(403, "Admin privileges required") def get(self, provider: str): # The role of the current user in the table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() @@ -49,7 +63,19 @@ class OAuthDataSource(Resource): return {"data": auth_url}, 200 +@console_ns.route("/oauth/data-source/callback/<string:provider>") class OAuthDataSourceCallback(Resource): + @api.doc("oauth_data_source_callback") + @api.doc(description="Handle OAuth callback from data source provider") + @api.doc( + params={ + "provider": "Data source provider name (notion)", + "code": "Authorization code from OAuth provider", + "error": "Error message from OAuth provider", + } + ) + @api.response(302, "Redirect to console with result") + @api.response(400, "Invalid provider") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -68,7 +94,19 @@ class OAuthDataSourceCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") +@console_ns.route("/oauth/data-source/binding/<string:provider>") class OAuthDataSourceBinding(Resource): + @api.doc("oauth_data_source_binding") + @api.doc(description="Bind OAuth data source with authorization code") + @api.doc( + params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"} + ) + @api.response( + 200, + "Data source binding success", + api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or code") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -90,7 +128,17 @@ class OAuthDataSourceBinding(Resource): return {"result": "success"}, 200 +@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") class OAuthDataSourceSync(Resource): + @api.doc("oauth_data_source_sync") + @api.doc(description="Sync data from OAuth data source") + @api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"}) + @api.response( + 200, + "Data source sync success", + api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or sync
failed") @setup_required @login_required @account_initialization_required @@ -111,9 +159,3 @@ class OAuthDataSourceSync(Resource): return {"error": "OAuth data source process failed"}, 400 return {"result": "success"}, 200 - - -api.add_resource(OAuthDataSource, "/oauth/data-source/") -api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") -api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") -api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py new file mode 100644 index 0000000000..91de19a78a --- /dev/null +++ b/api/controllers/console/auth/email_register.py @@ -0,0 +1,155 @@ +from flask import request +from flask_restx import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants.languages import languages +from controllers.console import api +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailCodeError, + EmailRegisterLimitError, + InvalidEmailError, + InvalidTokenError, + PasswordMismatchError, +) +from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError +from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required +from extensions.ext_database import db +from libs.helper import email, extract_remote_ip +from libs.password import valid_password +from models.account import Account +from services.account_service import AccountService +from services.billing_service import BillingService +from services.errors.account import AccountNotFoundError, AccountRegisterError + + +class EmailRegisterSendEmailApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + language = "en-US" + if args["language"] in languages: + language = args["language"] + + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + raise AccountInFreezeError() + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + token = None + token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) + return {"result": "success", "data": token} + + +class EmailRegisterCheckApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) + if is_email_register_error_rate_limit: + raise EmailRegisterLimitError() + + token_data = AccountService.get_email_register_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + 
raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_email_register_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_email_register_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_email_register_token( + user_email, code=args["code"], additional_data={"phase": "register"} + ) + + AccountService.reset_email_register_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class EmailRegisterResetApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + args = parser.parse_args() + + # Validate passwords match + if args["new_password"] != args["password_confirm"]: + raise PasswordMismatchError() + + # Validate token and get register data + register_data = AccountService.get_email_register_data(args["token"]) + if not register_data: + raise InvalidTokenError() + # Must use token in reset phase + if register_data.get("phase", "") != "register": + raise InvalidTokenError() + + # Revoke token to prevent reuse + AccountService.revoke_email_register_token(args["token"]) + + email = register_data.get("email", "") + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(email, args["password_confirm"]) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(email) + + return {"result": "success", "data": token_pair.model_dump()} + + def _create_new_account(self, email, password) -> Account | None: + # Create new account if allowed + account = None + try: + account = AccountService.create_account_and_tenant( + email=email, + name=email, + password=password, + interface_language=languages[0], + ) + except AccountRegisterError: + raise AccountInFreezeError() + + return account + + +api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email") +api.add_resource(EmailRegisterCheckApi, "/email-register/validity") +api.add_resource(EmailRegisterResetApi, "/email-register") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 7853bef917..81f1c6e70f 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,21 +27,43 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Too many password reset emails have been sent. Please try again in 1 minute." + description = "Too many password reset emails have been sent. Please try again in {minutes} minutes." 
code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + + +class EmailRegisterRateLimitExceededError(BaseHTTPException): + error_code = "email_register_rate_limit_exceeded" + description = "Too many email register emails have been sent. Please try again in {minutes} minutes." + code = 429 + + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailChangeRateLimitExceededError(BaseHTTPException): error_code = "email_change_rate_limit_exceeded" - description = "Too many email change emails have been sent. Please try again in 1 minute." + description = "Too many email change emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class OwnerTransferRateLimitExceededError(BaseHTTPException): error_code = "owner_transfer_rate_limit_exceeded" - description = "Too many owner transfer emails have been sent. Please try again in 1 minute." + description = "Too many owner transfer emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailCodeError(BaseHTTPException): error_code = "email_code_error" @@ -69,15 +91,23 @@ class EmailPasswordLoginLimitError(BaseHTTPException): class EmailCodeLoginRateLimitExceededError(BaseHTTPException): error_code = "email_code_login_rate_limit_exceeded" - description = "Too many login emails have been sent. Please try again in 5 minutes." + description = "Too many login emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 5): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException): error_code = "email_code_account_deletion_rate_limit_exceeded" - description = "Too many account deletion emails have been sent. Please try again in 5 minutes." + description = "Too many account deletion emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 5): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" @@ -85,6 +115,12 @@ class EmailPasswordResetLimitError(BaseHTTPException): code = 429 +class EmailRegisterLimitError(BaseHTTPException): + error_code = "email_register_limit" + description = "Too many failed email register attempts. Please try again in 24 hours." + code = 429 + + class EmailChangeLimitError(BaseHTTPException): error_code = "email_change_limit" description = "Too many failed email change attempts. Please try again in 24 hours." 
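
The reworked rate-limit errors above replace hard-coded wait times with a {minutes} placeholder that __init__ fills in, so a call site can report the cooldown it actually enforces instead of a fixed "1 minute" or "5 minutes" string. A brief usage sketch; the ten-minute value is hypothetical:

# --- illustrative sketch, not part of the patch ---
remaining_minutes = 10  # e.g. derived from a rate-limiter TTL at the call site
raise EmailCodeLoginRateLimitExceededError(minutes=remaining_minutes)
# -> HTTP 429: "Too many login emails have been sent. Please try again in 10 minutes."

Raising the error with no argument keeps the previous wording via the defaults (minutes=5 here, minutes=1 for the password-reset, email-register, email-change, and owner-transfer variants).
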
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ede0696854..36ccb1d562 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,12 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from constants.languages import languages -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.auth.error import ( EmailCodeError, EmailPasswordResetLimitError, @@ -15,7 +14,7 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError +from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db @@ -23,12 +22,35 @@ from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService -from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService +@console_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): + @api.doc("send_forgot_password_email") + @api.doc(description="Send password reset email") + @api.expect( + api.model( + "ForgotPasswordEmailRequest", + { + "email": fields.String(required=True, description="Email address"), + "language": fields.String(description="Language for email (zh-Hans/en-US)"), + }, + ) + ) + @api.response( + 200, + "Email sent successfully", + api.model( + "ForgotPasswordEmailResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.String(description="Reset token"), + "code": fields.String(description="Error code if account not found"), + }, + ), + ) + @api.response(400, "Invalid email or rate limit exceeded") @setup_required @email_password_login_enabled def post(self): @@ -48,20 +70,44 @@ class ForgotPasswordSendEmailApi(Resource): with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() - token = None - if account is None: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - return {"result": "fail", "data": token, "code": "account_not_found"} - else: - raise AccountNotFound() - else: - token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + token = AccountService.send_reset_password_email( + account=account, + email=args["email"], + language=language, + is_allow_register=FeatureService.get_system_features().is_allow_register, + ) return {"result": "success", "data": token} +@console_ns.route("/forgot-password/validity") class ForgotPasswordCheckApi(Resource): + @api.doc("check_forgot_password_code") + @api.doc(description="Verify password reset code") + @api.expect( + api.model( + "ForgotPasswordCheckRequest", + { + "email": 
fields.String(required=True, description="Email address"), + "code": fields.String(required=True, description="Verification code"), + "token": fields.String(required=True, description="Reset token"), + }, + ) + ) + @api.response( + 200, + "Code verified successfully", + api.model( + "ForgotPasswordCheckResponse", + { + "is_valid": fields.Boolean(description="Whether code is valid"), + "email": fields.String(description="Email address"), + "token": fields.String(description="New reset token"), + }, + ), + ) + @api.response(400, "Invalid code or token") @setup_required @email_password_login_enabled def post(self): @@ -100,7 +146,26 @@ class ForgotPasswordCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/forgot-password/resets") class ForgotPasswordResetApi(Resource): + @api.doc("reset_password") + @api.doc(description="Reset password with verification token") + @api.expect( + api.model( + "ForgotPasswordResetRequest", + { + "token": fields.String(required=True, description="Verification token"), + "new_password": fields.String(required=True, description="New password"), + "password_confirm": fields.String(required=True, description="Password confirmation"), + }, + ) + ) + @api.response( + 200, + "Password reset successfully", + api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid token or password mismatch") @setup_required @email_password_login_enabled def post(self): @@ -137,7 +202,7 @@ class ForgotPasswordResetApi(Resource): if account: self._update_existing_account(account, password_hashed, salt, session) else: - self._create_new_account(email, args["password_confirm"]) + raise AccountNotFound() return {"result": "success"} @@ -157,22 +222,6 @@ class ForgotPasswordResetApi(Resource): account.current_tenant = tenant tenant_was_created.send(tenant) - def _create_new_account(self, email, password): - # Create new account if allowed - try: - AccountService.create_account_and_tenant( - email=email, - name=email, - password=password, - interface_language=languages[0], - ) - except WorkSpaceNotAllowedCreateError: - pass - except WorkspacesLimitExceededError: - pass - except AccountRegisterError: - raise AccountInFreezeError() - api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index b11bc0c6ac..3b35ab3c23 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -26,7 +26,6 @@ from controllers.console.error import ( from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip -from libs.password import valid_password from models.account import Account from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService @@ -44,10 +43,9 @@ class LoginApi(Resource): """Authenticate user and login.""" parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("password", type=str, required=True, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, 
location="json") parser.add_argument("invite_token", type=str, required=False, default=None, location="json") - parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): @@ -61,11 +59,6 @@ class LoginApi(Resource): if invitation: invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) - if args["language"] is not None and args["language"] == "zh-Hans": - language = "zh-Hans" - else: - language = "en-US" - try: if invitation: data = invitation.get("data", {}) @@ -80,12 +73,6 @@ class LoginApi(Resource): except services.errors.account.AccountPasswordError: AccountService.add_login_error_rate_limit(args["email"]) raise AuthenticationFailedError() - except services.errors.account.AccountNotFoundError: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - return {"result": "fail", "data": token, "code": "account_not_found"} - else: - raise AccountNotFound() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: @@ -133,13 +120,12 @@ class ResetPasswordSendEmailApi(Resource): except AccountRegisterError: raise AccountInFreezeError() - if account is None: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - else: - raise AccountNotFound() - else: - token = AccountService.send_reset_password_email(account=account, language=language) + token = AccountService.send_reset_password_email( + email=args["email"], + account=account, + language=language, + is_allow_register=FeatureService.get_system_features().is_allow_register, + ) return {"result": "success", "data": token} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 332a98c474..1602ee6eea 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests from flask import current_app, redirect, request @@ -18,11 +17,12 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService +from services.billing_service import BillingService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService -from .. import api +from .. 
import api +from .. import api, console_ns logger = logging.getLogger(__name__) @@ -50,7 +50,13 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/login/<provider>") class OAuthLogin(Resource): + @api.doc("oauth_login") + @api.doc(description="Initiate OAuth login process") + @api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}) + @api.response(302, "Redirect to OAuth authorization URL") + @api.response(400, "Invalid provider") def get(self, provider: str): invite_token = request.args.get("invite_token") or None OAUTH_PROVIDERS = get_oauth_providers() @@ -63,7 +69,19 @@ class OAuthLogin(Resource): return redirect(auth_url) +@console_ns.route("/oauth/authorize/<provider>") class OAuthCallback(Resource): + @api.doc("oauth_callback") + @api.doc(description="Handle OAuth callback and complete login process") + @api.doc( + params={ + "provider": "OAuth provider name (github/google)", + "code": "Authorization code from OAuth provider", + "state": "Optional state parameter (used for invite token)", + } + ) + @api.response(302, "Redirect to console with access token") + @api.response(400, "OAuth process failed") def get(self, provider: str): OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -77,6 +95,9 @@ class OAuthCallback(Resource): if state: invite_token = state + if not code: + return {"error": "Authorization code is required"}, 400 + try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -86,7 +107,7 @@ class OAuthCallback(Resource): return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): - invitation = RegisterService._get_invitation_by_token(token=invite_token) + invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) if invitation_email != user_info.email: @@ -135,8 +156,8 @@ class OAuthCallback(Resource): ) -def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account: Optional[Account] = Account.get_by_openid(provider, user_info.id) +def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: + account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: with Session(db.engine) as session: @@ -162,7 +183,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: if not FeatureService.get_system_features().is_allow_register: - raise AccountNotFoundError() + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new account registration" + ) + ) + else: + raise AccountRegisterError(description="Invalid email or password") account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider @@ -181,7 +210,3 @@ class OAuthCallback(Resource): AccountService.link_account_integrate(provider, user_info.id, account) return account - - -api.add_resource(OAuthLogin, "/oauth/login/<provider>") -api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 04a2aa6594..3a9530af84 100644 ---
a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -34,14 +34,12 @@ class DataSourceApi(Resource): @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = ( - db.session.query(DataSourceOauthBinding) - .where( + data_source_integrates = db.session.scalars( + select(DataSourceOauthBinding).where( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.disabled == False, ) - .all() - ) + ).all() base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ef1fc5a958..74fb07f897 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,12 +1,13 @@ import flask_restx from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError @@ -47,7 +48,21 @@ def _validate_description_length(description): return description +@console_ns.route("/datasets") class DatasetListApi(Resource): + @api.doc("get_datasets") + @api.doc(description="Get list of datasets") + @api.doc( + params={ + "page": "Page number (default: 1)", + "limit": "Number of items per page (default: 20)", + "ids": "Filter by dataset IDs (list)", + "keyword": "Search keyword", + "tag_ids": "Filter by tag IDs (list)", + "include_all": "Include all datasets (default: false)", + } + ) + @api.response(200, "Datasets retrieved successfully") @setup_required @login_required @account_initialization_required @@ -99,6 +114,24 @@ class DatasetListApi(Resource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @api.doc("create_dataset") + @api.doc(description="Create a new dataset") + @api.expect( + api.model( + "CreateDatasetRequest", + { + "name": fields.String(required=True, description="Dataset name (1-40 characters)"), + "description": fields.String(description="Dataset description (max 400 characters)"), + "indexing_technique": fields.String(description="Indexing technique"), + "permission": fields.String(description="Dataset permission"), + "provider": fields.String(description="Provider"), + "external_knowledge_api_id": fields.String(description="External knowledge API ID"), + "external_knowledge_id": fields.String(description="External knowledge ID"), + }, + ) + ) + @api.response(201, "Dataset created successfully") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -171,7 +204,7 @@ class DatasetListApi(Resource): return marshal(dataset, dataset_detail_fields), 201 +@console_ns.route("/datasets/<uuid:dataset_id>") class DatasetApi(Resource): + @api.doc("get_dataset") + @api.doc(description="Get dataset details") + @api.doc(params={"dataset_id": "Dataset
ID"}) + @api.response(200, "Dataset retrieved successfully", dataset_detail_fields) + @api.response(404, "Dataset not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -214,6 +254,23 @@ class DatasetApi(Resource): return data, 200 + @api.doc("update_dataset") + @api.doc(description="Update dataset details") + @api.expect( + api.model( + "UpdateDatasetRequest", + { + "name": fields.String(description="Dataset name"), + "description": fields.String(description="Dataset description"), + "permission": fields.String(description="Dataset permission"), + "indexing_technique": fields.String(description="Indexing technique"), + "external_retrieval_model": fields.Raw(description="External retrieval model settings"), + }, + ) + ) + @api.response(200, "Dataset updated successfully", dataset_detail_fields) + @api.response(404, "Dataset not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -339,7 +396,7 @@ class DatasetApi(Resource): dataset_id_str = str(dataset_id) # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor or current_user.is_dataset_operator: + if not (current_user.is_editor or current_user.is_dataset_operator): raise Forbidden() try: @@ -352,7 +409,12 @@ class DatasetApi(Resource): raise DatasetInUseError() +@console_ns.route("/datasets//use-check") class DatasetUseCheckApi(Resource): + @api.doc("check_dataset_use") + @api.doc(description="Check if dataset is in use") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Dataset use status retrieved successfully") @setup_required @login_required @account_initialization_required @@ -363,7 +425,12 @@ class DatasetUseCheckApi(Resource): return {"is_using": dataset_is_using}, 200 +@console_ns.route("/datasets//queries") class DatasetQueryApi(Resource): + @api.doc("get_dataset_queries") + @api.doc(description="Get dataset query history") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Query history retrieved successfully", dataset_query_detail_fields) @setup_required @login_required @account_initialization_required @@ -393,7 +460,11 @@ class DatasetQueryApi(Resource): return response, 200 +@console_ns.route("/datasets/indexing-estimate") class DatasetIndexingEstimateApi(Resource): + @api.doc("estimate_dataset_indexing") + @api.doc(description="Estimate dataset indexing cost") + @api.response(200, "Indexing estimate calculated successfully") @setup_required @login_required @account_initialization_required @@ -420,11 +491,11 @@ class DatasetIndexingEstimateApi(Resource): extract_settings = [] if args["info_list"]["data_source_type"] == "upload_file": file_ids = args["info_list"]["file_info_list"]["file_ids"] - file_details = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) - .all() - ) + file_details = db.session.scalars( + select(UploadFile).where( + UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids) + ) + ).all() if file_details is None: raise NotFound("File not found.") @@ -496,7 +567,12 @@ class DatasetIndexingEstimateApi(Resource): return response.model_dump(), 200 +@console_ns.route("/datasets//related-apps") class DatasetRelatedAppListApi(Resource): + @api.doc("get_dataset_related_apps") + @api.doc(description="Get applications related to dataset") + @api.doc(params={"dataset_id": "Dataset ID"}) + 
@api.response(200, "Related apps retrieved successfully", related_app_list) @setup_required @login_required @account_initialization_required @@ -523,17 +599,22 @@ class DatasetRelatedAppListApi(Resource): return {"data": related_apps, "total": len(related_apps)}, 200 +@console_ns.route("/datasets//indexing-status") class DatasetIndexingStatusApi(Resource): + @api.doc("get_dataset_indexing_status") + @api.doc(description="Get dataset indexing status") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Indexing status retrieved successfully") @setup_required @login_required @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = ( - db.session.query(Document) - .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) - .all() - ) + documents = db.session.scalars( + select(Document).where( + Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id + ) + ).all() documents_status = [] for document in documents: completed_segments = ( @@ -570,21 +651,25 @@ class DatasetIndexingStatusApi(Resource): return data, 200 +@console_ns.route("/datasets/api-keys") class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" resource_type = "dataset" + @api.doc("get_dataset_api_keys") + @api.doc(description="Get dataset API keys") + @api.response(200, "API keys retrieved successfully", api_key_list) @setup_required @login_required @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) - .all() - ) + keys = db.session.scalars( + select(ApiToken).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id + ) + ).all() return {"items": keys} @setup_required @@ -619,9 +704,14 @@ class DatasetApiKeyApi(Resource): return api_token, 200 +@console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): resource_type = "dataset" + @api.doc("delete_dataset_api_key") + @api.doc(description="Delete dataset API key") + @api.doc(params={"api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") @setup_required @login_required @account_initialization_required @@ -662,7 +752,11 @@ class DatasetEnableApiApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/datasets/api-base-info") class DatasetApiBaseUrlApi(Resource): + @api.doc("get_dataset_api_base_info") + @api.doc(description="Get dataset API base information") + @api.response(200, "API base info retrieved successfully") @setup_required @login_required @account_initialization_required @@ -670,7 +764,11 @@ class DatasetApiBaseUrlApi(Resource): return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} +@console_ns.route("/datasets/retrieval-setting") class DatasetRetrievalSettingApi(Resource): + @api.doc("get_dataset_retrieval_setting") + @api.doc(description="Get dataset retrieval settings") + @api.response(200, "Retrieval settings retrieved successfully") @setup_required @login_required @account_initialization_required @@ -721,7 +819,12 @@ class DatasetRetrievalSettingApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") +@console_ns.route("/datasets/retrieval-setting/") class DatasetRetrievalSettingMockApi(Resource): + @api.doc("get_dataset_retrieval_setting_mock") + @api.doc(description="Get mock dataset 
retrieval settings by vector type") + @api.doc(params={"vector_type": "Vector store type"}) + @api.response(200, "Mock retrieval settings retrieved successfully") @setup_required @login_required @account_initialization_required @@ -770,7 +873,13 @@ class DatasetRetrievalSettingMockApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") +@console_ns.route("/datasets//error-docs") class DatasetErrorDocs(Resource): + @api.doc("get_dataset_error_docs") + @api.doc(description="Get dataset error documents") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Error documents retrieved successfully") + @api.response(404, "Dataset not found") @setup_required @login_required @account_initialization_required @@ -784,7 +893,14 @@ class DatasetErrorDocs(Resource): return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 +@console_ns.route("/datasets//permission-part-users") class DatasetPermissionUserListApi(Resource): + @api.doc("get_dataset_permission_users") + @api.doc(description="Get dataset permission user list") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Permission users retrieved successfully") + @api.response(404, "Dataset not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -805,7 +921,13 @@ class DatasetPermissionUserListApi(Resource): }, 200 +@console_ns.route("/datasets//auto-disable-logs") class DatasetAutoDisableLogApi(Resource): + @api.doc("get_dataset_auto_disable_logs") + @api.doc(description="Get dataset auto disable logs") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Auto disable logs retrieved successfully") + @api.response(404, "Dataset not found") @setup_required @login_required @account_initialization_required @@ -815,21 +937,3 @@ class DatasetAutoDisableLogApi(Resource): if dataset is None: raise NotFound("Dataset not found.") return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 - - -api.add_resource(DatasetListApi, "/datasets") -api.add_resource(DatasetApi, "/datasets/") -api.add_resource(DatasetUseCheckApi, "/datasets//use-check") -api.add_resource(DatasetQueryApi, "/datasets//queries") -api.add_resource(DatasetErrorDocs, "/datasets//error-docs") -api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") -api.add_resource(DatasetRelatedAppListApi, "/datasets//related-apps") -api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") -api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") -api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") -api.add_resource(DatasetEnableApiApi, "/datasets//") -api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") -api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") -api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") -api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") -api.add_resource(DatasetAutoDisableLogApi, "/datasets//auto-disable-logs") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 3f30ceab43..1b39f651fd 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,16 +1,17 @@ import json import logging from argparse import ArgumentTypeError +from collections.abc import Sequence from typing import Literal, cast from flask 
import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -81,7 +82,7 @@ class DocumentResource(Resource): return document - def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: + def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") @@ -99,7 +100,12 @@ class DocumentResource(Resource): return documents +@console_ns.route("/datasets/process-rule") class GetProcessRuleApi(Resource): + @api.doc("get_process_rule") + @api.doc(description="Get dataset document processing rules") + @api.doc(params={"document_id": "Document ID (optional)"}) + @api.response(200, "Process rules retrieved successfully") @setup_required @login_required @account_initialization_required @@ -141,7 +147,21 @@ class GetProcessRuleApi(Resource): return {"mode": mode, "rules": rules, "limits": limits} +@console_ns.route("/datasets//documents") class DatasetDocumentListApi(Resource): + @api.doc("get_dataset_documents") + @api.doc(description="Get documents in a dataset") + @api.doc( + params={ + "dataset_id": "Dataset ID", + "page": "Page number (default: 1)", + "limit": "Number of items per page (default: 20)", + "keyword": "Search keyword", + "sort": "Sort order (default: -created_at)", + "fetch": "Fetch full details (default: false)", + } + ) + @api.response(200, "Documents retrieved successfully") @setup_required @login_required @account_initialization_required @@ -325,7 +345,23 @@ class DatasetDocumentListApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/datasets/init") class DatasetInitApi(Resource): + @api.doc("init_dataset") + @api.doc(description="Initialize dataset with documents") + @api.expect( + api.model( + "DatasetInitRequest", + { + "upload_file_id": fields.String(required=True, description="Upload file ID"), + "indexing_technique": fields.String(description="Indexing technique"), + "process_rule": fields.Raw(description="Processing rules"), + "data_source": fields.Raw(description="Data source configuration"), + }, + ) + ) + @api.response(201, "Dataset initialized successfully", dataset_and_document_fields) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -395,7 +431,14 @@ class DatasetInitApi(Resource): return response +@console_ns.route("/datasets//documents//indexing-estimate") class DocumentIndexingEstimateApi(DocumentResource): + @api.doc("estimate_document_indexing") + @api.doc(description="Estimate document indexing cost") + @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @api.response(200, "Indexing estimate calculated successfully") + @api.response(404, "Document not found") + @api.response(400, "Document already finished") @setup_required @login_required @account_initialization_required @@ -595,7 +638,13 @@ class DocumentBatchIndexingStatusApi(DocumentResource): return data +@console_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DocumentResource): + 
@api.doc("get_document_indexing_status") + @api.doc(description="Get document indexing status") + @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @api.response(200, "Indexing status retrieved successfully") + @api.response(404, "Document not found") @setup_required @login_required @account_initialization_required @@ -637,9 +686,21 @@ class DocumentIndexingStatusApi(DocumentResource): return marshal(document_dict, document_status_fields) +@console_ns.route("/datasets//documents/") class DocumentApi(DocumentResource): METADATA_CHOICES = {"all", "only", "without"} + @api.doc("get_document") + @api.doc(description="Get document details") + @api.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "metadata": "Metadata inclusion (all/only/without)", + } + ) + @api.response(200, "Document retrieved successfully") + @api.response(404, "Document not found") @setup_required @login_required @account_initialization_required @@ -748,7 +809,16 @@ class DocumentApi(DocumentResource): return {"result": "success"}, 204 +@console_ns.route("/datasets//documents//processing/") class DocumentProcessingApi(DocumentResource): + @api.doc("update_document_processing") + @api.doc(description="Update document processing status (pause/resume)") + @api.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"} + ) + @api.response(200, "Processing status updated successfully") + @api.response(404, "Document not found") + @api.response(400, "Invalid action") @setup_required @login_required @account_initialization_required @@ -783,7 +853,23 @@ class DocumentProcessingApi(DocumentResource): return {"result": "success"}, 200 +@console_ns.route("/datasets//documents//metadata") class DocumentMetadataApi(DocumentResource): + @api.doc("update_document_metadata") + @api.doc(description="Update document metadata") + @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @api.expect( + api.model( + "UpdateDocumentMetadataRequest", + { + "doc_type": fields.String(description="Document type"), + "doc_metadata": fields.Raw(description="Document metadata"), + }, + ) + ) + @api.response(200, "Document metadata updated successfully") + @api.response(404, "Document not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 043f39f623..e8f5a11b41 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,10 +1,10 @@ from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, reqparse +from flask_restx import Resource, fields, marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, setup_required from fields.dataset_fields import dataset_detail_fields @@ -21,7 +21,18 @@ def _validate_name(name): return name +@console_ns.route("/datasets/external-knowledge-api") class ExternalApiTemplateListApi(Resource): + @api.doc("get_external_api_templates") + @api.doc(description="Get external knowledge API templates") + @api.doc( + params={ + "page": "Page number 
(default: 1)", + "limit": "Number of items per page (default: 20)", + "keyword": "Search keyword", + } + ) + @api.response(200, "External API templates retrieved successfully") @setup_required @login_required @account_initialization_required @@ -79,7 +90,13 @@ class ExternalApiTemplateListApi(Resource): return external_knowledge_api.to_dict(), 201 +@console_ns.route("/datasets/external-knowledge-api/") class ExternalApiTemplateApi(Resource): + @api.doc("get_external_api_template") + @api.doc(description="Get external knowledge API template details") + @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) + @api.response(200, "External API template retrieved successfully") + @api.response(404, "Template not found") @setup_required @login_required @account_initialization_required @@ -131,14 +148,19 @@ class ExternalApiTemplateApi(Resource): external_knowledge_api_id = str(external_knowledge_api_id) # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor or current_user.is_dataset_operator: + if not (current_user.is_editor or current_user.is_dataset_operator): raise Forbidden() ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) return {"result": "success"}, 204 +@console_ns.route("/datasets/external-knowledge-api//use-check") class ExternalApiUseCheckApi(Resource): + @api.doc("check_external_api_usage") + @api.doc(description="Check if external knowledge API is being used") + @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) + @api.response(200, "Usage check completed successfully") @setup_required @login_required @account_initialization_required @@ -151,7 +173,24 @@ class ExternalApiUseCheckApi(Resource): return {"is_using": external_knowledge_api_is_using, "count": count}, 200 +@console_ns.route("/datasets/external") class ExternalDatasetCreateApi(Resource): + @api.doc("create_external_dataset") + @api.doc(description="Create external knowledge dataset") + @api.expect( + api.model( + "CreateExternalDatasetRequest", + { + "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"), + "external_knowledge_id": fields.String(required=True, description="External knowledge ID"), + "name": fields.String(required=True, description="Dataset name"), + "description": fields.String(description="Dataset description"), + }, + ) + ) + @api.response(201, "External dataset created successfully", dataset_detail_fields) + @api.response(400, "Invalid parameters") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -191,7 +230,24 @@ class ExternalDatasetCreateApi(Resource): return marshal(dataset, dataset_detail_fields), 201 +@console_ns.route("/datasets//external-hit-testing") class ExternalKnowledgeHitTestingApi(Resource): + @api.doc("test_external_knowledge_retrieval") + @api.doc(description="Test external knowledge retrieval for dataset") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.expect( + api.model( + "ExternalHitTestingRequest", + { + "query": fields.String(required=True, description="Query text for testing"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "external_retrieval_model": fields.Raw(description="External retrieval model configuration"), + }, + ) + ) + @api.response(200, "External hit testing completed successfully") + @api.response(404, "Dataset not found") + @api.response(400, "Invalid 
parameters") @setup_required @login_required @account_initialization_required @@ -228,8 +284,22 @@ class ExternalKnowledgeHitTestingApi(Resource): raise InternalServerError(str(e)) +@console_ns.route("/test/retrieval") class BedrockRetrievalApi(Resource): # this api is only for internal testing + @api.doc("bedrock_retrieval_test") + @api.doc(description="Bedrock retrieval test (internal use only)") + @api.expect( + api.model( + "BedrockRetrievalTestRequest", + { + "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"), + "query": fields.String(required=True, description="Query text"), + "knowledge_id": fields.String(required=True, description="Knowledge ID"), + }, + ) + ) + @api.response(200, "Bedrock retrieval test completed") def post(self): parser = reqparse.RequestParser() parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") @@ -247,12 +317,3 @@ class BedrockRetrievalApi(Resource): args["retrieval_setting"], args["query"], args["knowledge_id"] ) return result, 200 - - -api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets//external-hit-testing") -api.add_resource(ExternalDatasetCreateApi, "/datasets/external") -api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") -api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/") -api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api//use-check") -# this api is only for internal test -api.add_resource(BedrockRetrievalApi, "/test/retrieval") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 2ad192571b..abaca88090 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,6 +1,6 @@ -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.wraps import ( account_initialization_required, @@ -10,7 +10,25 @@ from controllers.console.wraps import ( from libs.login import login_required +@console_ns.route("/datasets//hit-testing") class HitTestingApi(Resource, DatasetsHitTestingBase): + @api.doc("test_dataset_retrieval") + @api.doc(description="Test dataset knowledge retrieval") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.expect( + api.model( + "HitTestingRequest", + { + "query": fields.String(required=True, description="Query text for testing"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "top_k": fields.Integer(description="Number of top results to return"), + "score_threshold": fields.Float(description="Score threshold for filtering results"), + }, + ) + ) + @api.response(200, "Hit testing completed successfully") + @api.response(404, "Dataset not found") + @api.response(400, "Invalid parameters") @setup_required @login_required @account_initialization_required @@ -23,6 +41,3 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) - - -api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 964de0a863..c70343ec95 100644 --- 
a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -950,6 +950,12 @@ class RagPipelineTransformApi(Resource): @login_required @account_initialization_required def post(self, dataset_id): + if not isinstance(current_user, Account): + raise Forbidden() + + if not (current_user.is_editor or current_user.is_dataset_operator): + raise Forbidden() + dataset_id = str(dataset_id) rag_pipeline_transform_service = RagPipelineTransformService() result = rag_pipeline_transform_service.transform_dataset(dataset_id) diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bdaa268462..b9c1f65bfd 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,13 +1,32 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService +@console_ns.route("/website/crawl") class WebsiteCrawlApi(Resource): + @api.doc("crawl_website") + @api.doc(description="Crawl website content") + @api.expect( + api.model( + "WebsiteCrawlRequest", + { + "provider": fields.String( + required=True, + description="Crawl provider (firecrawl/watercrawl/jinareader)", + enum=["firecrawl", "watercrawl", "jinareader"], + ), + "url": fields.String(required=True, description="URL to crawl"), + "options": fields.Raw(required=True, description="Crawl options"), + }, + ) + ) + @api.response(200, "Website crawl initiated successfully") + @api.response(400, "Invalid crawl parameters") @setup_required @login_required @account_initialization_required @@ -39,7 +58,14 @@ class WebsiteCrawlApi(Resource): return result, 200 +@console_ns.route("/website/crawl/status/") class WebsiteCrawlStatusApi(Resource): + @api.doc("get_crawl_status") + @api.doc(description="Get website crawl status") + @api.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) + @api.response(200, "Crawl status retrieved successfully") + @api.response(404, "Crawl job not found") + @api.response(400, "Invalid provider") @setup_required @login_required @account_initialization_required @@ -62,7 +88,3 @@ class WebsiteCrawlStatusApi(Resource): except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 - - -api.add_resource(WebsiteCrawlApi, "/website/crawl") -api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/") diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index cc46f54ea3..a99708b7cd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -28,6 +27,8 @@ from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from 
services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) @@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 @@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) @@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 43ad3ecfbd..1aef9c544d 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session @@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError @@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource): pinned = args["pinned"] == "true" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") with Session(db.engine) as session: return WebConversationService.pagination_by_last_id( session=session, @@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return ConversationService.rename( app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) @@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource): conversation_id = str(c_id) try: + 
if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.pin(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource): raise NotChatAppError() conversation_id = str(c_id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3ccedd654b..bdc3fb0dbd 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,9 +2,8 @@ import logging from typing import Any from flask import request -from flask_login import current_user from flask_restx import Resource, inputs, marshal_with, reqparse -from sqlalchemy import and_ +from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound from controllers.console import api @@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required -from models import App, InstalledApp, RecommendedApp +from libs.login import current_user, login_required +from models import Account, App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -29,17 +28,23 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): app_id = request.args.get("app_id", default=None, type=str) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id if app_id: - installed_apps = ( - db.session.query(InstalledApp) - .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) - .all() - ) + installed_apps = db.session.scalars( + select(InstalledApp).where( + and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) + ) + ).all() else: - installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() + installed_apps = db.session.scalars( + select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id) + ).all() + if current_user.current_tenant is None: + raise ValueError("current_user.current_tenant must not be None") current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ { @@ -115,6 +120,8 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("App not found") + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id app = db.session.query(App).where(App.id == args["app_id"]).first() @@ -154,6 +161,8 @@ class InstalledAppApi(InstalledAppResource): """ def delete(self, installed_app): + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == 
current_user.current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 608bc6d007..c46c1c1f4f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] ) @@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") MessageService.create_feedback( app_model=app_model, message_id=message_id, @@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, @@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource): message_id = str(message_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index d9afb5bab2..7742ea24a9 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource): if app_model is None: raise AppUnavailableError() - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 62f9350b71..974222ddf7 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,11 +1,10 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField -from libs.login import login_required +from libs.login import current_user, login_required 
from services.recommended_app_service import RecommendedAppService app_fields = { @@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource): parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get("language") and args.get("language") in languages: - language_prefix = args.get("language") + language = args.get("language") + if language and language in languages: + language_prefix = language elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 5353dbcad5..6f05f898f9 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound @@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value +from libs.login import current_user +from models import Account from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource): parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): @@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 6401f804c0..3a8ba64a03 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask_login import current_user from flask_restx import Resource @@ -20,7 +20,7 @@ R = TypeVar("R") T = TypeVar("T") -def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): +def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): @@ -50,7 +50,7 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], return decorator 
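The Optional-to-union rewrites in these wrapper signatures are mechanical, but the signature they touch deserves a note: Concatenate[InstalledApp, P] declares that the wrapped view takes an InstalledApp as its first positional argument, with ParamSpec P capturing the rest, so the decorator can trade the route's installed_app_id string for a loaded record without erasing the view's remaining parameter types. A self-contained sketch of the same pattern on Python 3.10+ (the model and lookup below are toy placeholders, not this codebase's):

from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class InstalledApp:  # toy stand-in for the real model
    def __init__(self, app_id: str):
        self.app_id = app_id


def _load_installed_app(installed_app_id: str) -> InstalledApp:
    # Placeholder for the database lookup the real decorator performs.
    return InstalledApp(installed_app_id)


def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
    def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
        @wraps(view)
        def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs) -> R:
            # Swap the captured id for the loaded record, so the view's first
            # argument is an InstalledApp rather than a string.
            return view(_load_installed_app(installed_app_id), *args, **kwargs)

        return decorated

    return decorator(view) if view else decorator


@installed_app_required
def show_app(installed_app: InstalledApp, verbose: bool = False) -> str:
    return installed_app.app_id


print(show_app("abc-123"))  # the decorator resolves "abc-123" into an InstalledApp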
-def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): +def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index e157041c35..57f5ab191e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from constants import HIDDEN_VALUE -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required @@ -11,7 +11,21 @@ from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +@console_ns.route("/code-based-extension") class CodeBasedExtensionAPI(Resource): + @api.doc("get_code_based_extension") + @api.doc(description="Get code-based extension data by module name") + @api.expect( + api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name") + ) + @api.response( + 200, + "Success", + api.model( + "CodeBasedExtensionResponse", + {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, + ), + ) @setup_required @login_required @account_initialization_required @@ -23,7 +37,11 @@ class CodeBasedExtensionAPI(Resource): return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} +@console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): + @api.doc("get_api_based_extensions") + @api.doc(description="Get all API-based extensions for current tenant") + @api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields))) @setup_required @login_required @account_initialization_required @@ -32,6 +50,19 @@ class APIBasedExtensionAPI(Resource): tenant_id = current_user.current_tenant_id return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + @api.doc("create_api_based_extension") + @api.doc(description="Create a new API-based extension") + @api.expect( + api.model( + "CreateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(201, "Extension created successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -53,7 +84,12 @@ class APIBasedExtensionAPI(Resource): return APIBasedExtensionService.save(extension_data) +@console_ns.route("/api-based-extension/") class APIBasedExtensionDetailAPI(Resource): + @api.doc("get_api_based_extension") + @api.doc(description="Get API-based extension by ID") + @api.doc(params={"id": "Extension ID"}) + @api.response(200, "Success", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -64,6 +100,20 @@ class 
APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + @api.doc("update_api_based_extension") + @api.doc(description="Update API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.expect( + api.model( + "UpdateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(200, "Extension updated successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -88,6 +138,10 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.save(extension_data_from_db) + @api.doc("delete_api_based_extension") + @api.doc(description="Delete API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.response(204, "Extension deleted successfully") @setup_required @login_required @account_initialization_required @@ -100,9 +154,3 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) return {"result": "success"}, 204 - - -api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") - -api.add_resource(APIBasedExtensionAPI, "/api-based-extension") -api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 6236832d39..d43b839291 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,26 +1,40 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from libs.login import login_required from services.feature_service import FeatureService -from . import api +from .
import api, console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required +@console_ns.route("/features") class FeatureApi(Resource): + @api.doc("get_tenant_features") + @api.doc(description="Get feature configuration for current tenant") + @api.response( + 200, + "Success", + api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) @setup_required @login_required @account_initialization_required @cloud_utm_record def get(self): + """Get feature configuration for current tenant""" return FeatureService.get_features(current_user.current_tenant_id).model_dump() +@console_ns.route("/system-features") class SystemFeatureApi(Resource): + @api.doc("get_system_features") + @api.doc(description="Get system-wide feature configuration") + @api.response( + 200, + "Success", + api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}), + ) def get(self): + """Get system-wide feature configuration""" return FeatureService.get_system_features().model_dump() - - -api.add_resource(FeatureApi, "/features") -api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 56412d5bda..105f802878 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -23,6 +23,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required +from models import Account from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -68,6 +69,9 @@ class FileApi(Resource): if source not in ("datasets", None): source = None + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + try: upload_file = FileService(db.engine).upload_file( filename=file.filename, diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2a37b1708a..30b53458b2 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -11,20 +11,47 @@ from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService -from . import api +from . 
import api, console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted +@console_ns.route("/init") class InitValidateAPI(Resource): + @api.doc("get_init_status") + @api.doc(description="Get initialization validation status") + @api.response( + 200, + "Success", + model=api.model( + "InitStatusResponse", + {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, + ), + ) def get(self): + """Get initialization validation status""" init_status = get_init_validate_status() if init_status: return {"status": "finished"} return {"status": "not_started"} + @api.doc("validate_init_password") + @api.doc(description="Validate initialization password for self-hosted edition") + @api.expect( + api.model( + "InitValidateRequest", + {"password": fields.String(required=True, description="Initialization password", max_length=30)}, + ) + ) + @api.response( + 201, + "Success", + model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Validate initialization password""" # is tenant created tenant_count = TenantService.get_tenant_count() if tenant_count > 0: @@ -52,6 +79,3 @@ def get_init_validate_status(): return db_session.execute(select(DifySetup)).scalar_one_or_none() return True - - -api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 1a53a2347e..29f49b99de 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,14 +1,17 @@ -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from . import api, console_ns +@console_ns.route("/ping") class PingApi(Resource): + @api.doc("health_check") + @api.doc(description="Health check endpoint for connection testing") + @api.response( + 200, + "Success", + api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}), + ) def get(self): - """ - For connection health check - """ + """Health check endpoint for connection testing""" return {"result": "pong"} - - -api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 8e230496f0..bff5fc1651 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip @@ -7,23 +7,56 @@ from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService -from . import api +from . 
import api, console_ns from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted +@console_ns.route("/setup") class SetupApi(Resource): + @api.doc("get_setup_status") + @api.doc(description="Get system setup status") + @api.response( + 200, + "Success", + api.model( + "SetupStatusResponse", + { + "step": fields.String(description="Setup step status", enum=["not_started", "finished"]), + "setup_at": fields.String(description="Setup completion time (ISO format)", required=False), + }, + ), + ) def get(self): + """Get system setup status""" if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() - if setup_status: + # Check if setup_status is a DifySetup object rather than a bool + if setup_status and not isinstance(setup_status, bool): return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + elif setup_status: + return {"step": "finished"} return {"step": "not_started"} return {"step": "finished"} + @api.doc("setup_system") + @api.doc(description="Initialize system setup with admin account") + @api.expect( + api.model( + "SetupRequest", + { + "email": fields.String(required=True, description="Admin email address"), + "name": fields.String(required=True, description="Admin name (max 30 characters)"), + "password": fields.String(required=True, description="Admin password"), + }, + ) + ) + @api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")})) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Initialize system setup with admin account""" # is set up if get_setup_status(): raise AlreadySetupError() @@ -55,6 +88,3 @@ def get_setup_status(): return db.session.query(DifySetup).first() else: return True - - -api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 95515c38f9..8d081ad995 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,18 +2,41 @@ import json import logging import requests -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from packaging import version from configs import dify_config -from . import api +from . 
import api, console_ns logger = logging.getLogger(__name__) +@console_ns.route("/version") class VersionApi(Resource): + @api.doc("check_version_update") + @api.doc(description="Check for application version updates") + @api.expect( + api.parser().add_argument( + "current_version", type=str, required=True, location="args", help="Current application version" + ) + ) + @api.response( + 200, + "Success", + api.model( + "VersionResponse", + { + "version": fields.String(description="Latest version number"), + "release_date": fields.String(description="Release date of latest version"), + "release_notes": fields.String(description="Release notes for latest version"), + "can_auto_update": fields.Boolean(description="Whether auto-update is supported"), + "features": fields.Raw(description="Feature flags and capabilities"), + }, + ), + ) def get(self): + """Check for application version updates""" parser = reqparse.RequestParser() parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() @@ -34,14 +57,14 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) + response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) except Exception as error: logger.warning("Check update version error: %s.", str(error)) - result["version"] = args.get("current_version") + result["version"] = args["current_version"] return result content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): + if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): result["version"] = content["version"] result["release_date"] = content["releaseDate"] result["release_notes"] = content["releaseNotes"] @@ -59,6 +82,3 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: except version.InvalidVersion: logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) return False - - -api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 5b2828dbab..7a41a8a5cc 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -49,6 +49,8 @@ class AccountInitApi(Resource): @setup_required @login_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user if account.status == "active": @@ -102,6 +104,8 @@ class AccountProfileApi(Resource): @marshal_with(account_fields) @enterprise_license_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") return current_user @@ -111,6 +115,8 @@ class AccountNameApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() @@ -130,6 +136,8 @@ class AccountAvatarApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("avatar", 
type=str, required=True, location="json") args = parser.parse_args() @@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() @@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() @@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() @@ -194,6 +208,8 @@ class AccountPasswordApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("new_password", type=str, required=True, location="json") @@ -228,9 +244,13 @@ class AccountIntegrateApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user - account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.scalars( + select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) + ).all() base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" @@ -268,6 +288,8 @@ class AccountDeleteVerifyApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user token, code = AccountService.generate_account_deletion_verification_code(account) @@ -281,6 +303,8 @@ class AccountDeleteApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -321,6 +345,8 @@ class EducationVerifyApi(Resource): @cloud_edition_billing_enabled @marshal_with(verify_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user return BillingService.EducationIdentity.verify(account.id, account.email) @@ -340,6 +366,8 @@ class EducationApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -357,6 +385,8 @@ class EducationApi(Resource): @cloud_edition_billing_enabled @marshal_with(status_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid 
user account") account = current_user res = BillingService.EducationIdentity.status(account.id) @@ -421,6 +451,8 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if user_email != current_user.email: raise InvalidEmailError() else: @@ -501,6 +533,8 @@ class ChangeEmailResetApi(Resource): AccountService.revoke_change_email_token(args["token"]) old_email = reset_data.get("old_email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if current_user.email != old_email: raise AccountNotFound() diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 08bab6fcb5..0a2c8fcfb4 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,14 +1,22 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from services.agent_service import AgentService +@console_ns.route("/workspaces/current/agent-providers") class AgentProviderListApi(Resource): + @api.doc("list_agent_providers") + @api.doc(description="Get list of available agent providers") + @api.response( + 200, + "Success", + fields.List(fields.Raw(description="Agent provider information")), + ) @setup_required @login_required @account_initialization_required @@ -21,7 +29,16 @@ class AgentProviderListApi(Resource): return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) +@console_ns.route("/workspaces/current/agent-provider/") class AgentProviderApi(Resource): + @api.doc("get_agent_provider") + @api.doc(description="Get specific agent provider details") + @api.doc(params={"provider_name": "Agent provider name"}) + @api.response( + 200, + "Success", + fields.Raw(description="Agent provider details"), + ) @setup_required @login_required @account_initialization_required @@ -30,7 +47,3 @@ class AgentProviderApi(Resource): user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) - - -api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers") -api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 96e873d42b..0657b764cc 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError @@ -10,7 +10,26 @@ from libs.login import login_required from services.plugin.endpoint_service import EndpointService 
+@console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): + @api.doc("create_endpoint") + @api.doc(description="Create a new plugin endpoint") + @api.expect( + api.model( + "EndpointCreateRequest", + { + "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"), + "settings": fields.Raw(required=True, description="Endpoint settings"), + "name": fields.String(required=True, description="Endpoint name"), + }, + ) + ) + @api.response( + 200, + "Endpoint created successfully", + api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -43,7 +62,20 @@ class EndpointCreateApi(Resource): raise ValueError(e.description) from e +@console_ns.route("/workspaces/current/endpoints/list") class EndpointListApi(Resource): + @api.doc("list_endpoints") + @api.doc(description="List plugin endpoints with pagination") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + ) + @api.response( + 200, + "Success", + api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}), + ) @setup_required @login_required @account_initialization_required @@ -70,7 +102,23 @@ class EndpointListApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/list/plugin") class EndpointListForSinglePluginApi(Resource): + @api.doc("list_plugin_endpoints") + @api.doc(description="List endpoints for a specific plugin") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID") + ) + @api.response( + 200, + "Success", + api.model( + "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} + ), + ) @setup_required @login_required @account_initialization_required @@ -100,7 +148,19 @@ class EndpointListForSinglePluginApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/delete") class EndpointDeleteApi(Resource): + @api.doc("delete_endpoint") + @api.doc(description="Delete a plugin endpoint") + @api.expect( + api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint deleted successfully", + api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -123,7 +183,26 @@ class EndpointDeleteApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/update") class EndpointUpdateApi(Resource): + @api.doc("update_endpoint") + @api.doc(description="Update a plugin endpoint") + @api.expect( + api.model( + "EndpointUpdateRequest", + { + "endpoint_id": fields.String(required=True, description="Endpoint ID"), + "settings": fields.Raw(required=True, description="Updated settings"), + "name": fields.String(required=True, description="Updated name"), + }, + ) + ) + @api.response( + 200, + "Endpoint updated successfully", + api.model("EndpointUpdateResponse", 
{"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -154,7 +233,19 @@ class EndpointUpdateApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/enable") class EndpointEnableApi(Resource): + @api.doc("enable_endpoint") + @api.doc(description="Enable a plugin endpoint") + @api.expect( + api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint enabled successfully", + api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -177,7 +268,19 @@ class EndpointEnableApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/disable") class EndpointDisableApi(Resource): + @api.doc("disable_endpoint") + @api.doc(description="Disable a plugin endpoint") + @api.expect( + api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint disabled successfully", + api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -198,12 +301,3 @@ class EndpointDisableApi(Resource): tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id ) } - - -api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create") -api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list") -api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin") -api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete") -api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update") -api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable") -api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable") diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf2a10f453..77f0c9a735 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,8 +1,8 @@ from urllib import parse -from flask import request +from flask import abort, request from flask_login import current_user -from flask_restx import Resource, abort, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse import services from configs import dify_config @@ -41,6 +41,10 @@ class MemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource): if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") inviter = current_user + if not inviter.current_tenant: + raise ValueError("No current tenant") invitation_results = [] 
console_web_url = dify_config.CONSOLE_WEB_URL @@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource): for invitee_email in invitee_emails: try: + if not inviter.current_tenant: + raise ValueError("No current tenant") token = RegisterService.invite_new_member( inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter ) @@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource): return { "result": "success", "invitation_results": invitation_results, - "tenant_id": str(current_user.current_tenant.id), + "tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "", }, 201 @@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.query(Account).where(Account.id == str(member_id)).first() if member is None: abort(404) @@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource): except Exception as e: raise ValueError(str(e)) - return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 + return { + "result": "success", + "tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "", + }, 200 class MemberUpdateRoleApi(Resource): @@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) if not member: abort(404) @@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource): raise EmailSendIpLimitError() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource): account=current_user, email=email, language=language, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) return {"result": "success", "data": token} @@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource): parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -256,6 +289,10 @@ class OwnerTransfer(Resource): args = 
parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -274,9 +311,11 @@ class OwnerTransfer(Resource): member = db.session.get(Account, str(member_id)) if not member: abort(404) - else: - member_account = member - if not TenantService.is_member(member_account, current_user.current_tenant): + return # Never reached, but helps type checker + + if not current_user.current_tenant: + raise ValueError("No current tenant") + if not TenantService.is_member(member, current_user.current_tenant): raise MemberNotInTenantError() try: @@ -286,13 +325,13 @@ class OwnerTransfer(Resource): AccountService.send_new_owner_transfer_notify_email( account=member, email=member.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) AccountService.send_old_owner_transfer_notify_email( account=current_user, email=current_user.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", new_owner_email=member.email, ) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index bfcc9a7f0a..0c9db660aa 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value from libs.login import login_required +from models.account import Account from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -21,6 +22,10 @@ class ModelProviderListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id # if credential_id is not provided, return current used credential parser = reqparse.RequestParser() @@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.create_provider_credential( tenant_id=current_user.current_tenant_id, @@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def 
put(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.update_provider_credential( tenant_id=current_user.current_tenant_id, @@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] @@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") service = ModelProviderService() service.switch_active_provider_credential( tenant_id=current_user.current_tenant_id, @@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() @@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): def get(self, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") BillingService.is_tenant_owner_or_admin(current_user) + if not current_user.current_tenant_id: + raise ValueError("No current tenant") data = BillingService.get_model_provider_payment_link( provider_name=provider, tenant_id=current_user.current_tenant_id, diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 3c8299b2a1..6bec70b5da 100644 --- 
a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant, TenantStatus +from models.account import Account, Tenant, TenantStatus from services.account_service import TenantService from services.feature_service import FeatureService from services.file_service import FileService @@ -70,6 +70,8 @@ class TenantListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] @@ -83,7 +85,7 @@ class TenantListApi(Resource): "status": tenant.status, "created_at": tenant.created_at, "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", - "current": tenant.id == current_user.current_tenant_id, + "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, } tenant_dicts.append(tenant_dict) @@ -125,7 +127,11 @@ class TenantApi(Resource): if request.path == "/info": logger.warning("Deprecated URL /info was used.") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenant = current_user.current_tenant + if not tenant: + raise ValueError("No current tenant") if tenant.status == TenantStatus.ARCHIVE: tenants = TenantService.get_join_tenants(current_user) @@ -137,6 +143,8 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") + if not tenant: + raise ValueError("No tenant available") return WorkspaceService.get_tenant_info(tenant), 200 @@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() @@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) custom_config_dict = { @@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") # check file if "file" not in request.files: raise NoFileUploadedError() @@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource): @account_initialization_required # Change workspace name def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = 
db.get_or_404(Tenant, current_user.current_tenant_id) tenant.name = args["name"] db.session.commit() diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 07d3b0091a..914d386c78 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -242,6 +242,19 @@ def email_password_login_enabled(view: Callable[P, R]): return decorated +def email_register_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.is_allow_register: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated + + def enable_change_email(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 821ad220a2..f8976b86b9 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Files API", description="API for file operations including upload and preview", - doc="/docs", # Enable Swagger UI at /files/docs ) files_ns = Namespace("files", description="File operations", path="/") @@ -18,3 +17,12 @@ files_ns = Namespace("files", description="File operations", path="/") from . import image_preview, tool_files, upload api.add_namespace(files_ns) + +__all__ = [ + "api", + "bp", + "files_ns", + "image_preview", + "tool_files", + "upload", +] diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 7a2b3b0428..206a5d1cc2 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,5 +1,4 @@ from mimetypes import guess_extension -from typing import Optional from flask_restx import Resource, reqparse from flask_restx.api import HTTPStatus @@ -73,11 +72,11 @@ class PluginUploadFileApi(Resource): nonce: str = args["nonce"] sign: str = args["sign"] tenant_id: str = args["tenant_id"] - user_id: Optional[str] = args.get("user_id") + user_id: str | None = args.get("user_id") user = get_user(tenant_id, user_id) - filename: Optional[str] = file.filename - mimetype: Optional[str] = file.mimetype + filename: str | None = file.filename + mimetype: str | None = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") @@ -86,7 +85,7 @@ class PluginUploadFileApi(Resource): filename=filename, mimetype=mimetype, tenant_id=tenant_id, - user_id=user_id, + user_id=user.id, timestamp=timestamp, nonce=nonce, sign=sign, diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index d29a7be139..74005217ef 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -10,14 +10,22 @@ api = ExternalApi( version="1.0", title="Inner API", description="Internal APIs for enterprise features, billing, and plugin communication", - doc="/docs", # Enable Swagger UI at /inner/api/docs ) # Create namespace inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") -from . import mail -from .plugin import plugin -from .workspace import workspace +from . 
import mail as _mail +from .plugin import plugin as _plugin +from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) + +__all__ = [ + "_mail", + "_plugin", + "_workspace", + "api", + "bp", + "inner_api_ns", +] diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 170a794d89..c5bb2f2545 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -37,9 +37,9 @@ from models.model import EndUser @inner_api_ns.route("/invoke/llm") class PluginInvokeLLMApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLM) @inner_api_ns.doc("plugin_invoke_llm") @inner_api_ns.doc(description="Invoke LLM models through plugin interface") @@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource): @inner_api_ns.route("/invoke/llm/structured-output") class PluginInvokeLLMWithStructuredOutputApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) @inner_api_ns.doc("plugin_invoke_llm_structured") @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") @@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): @inner_api_ns.route("/invoke/text-embedding") class PluginInvokeTextEmbeddingApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) @inner_api_ns.doc("plugin_invoke_text_embedding") @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") @@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource): @inner_api_ns.route("/invoke/rerank") class PluginInvokeRerankApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeRerank) @inner_api_ns.doc("plugin_invoke_rerank") @inner_api_ns.doc(description="Invoke rerank models through plugin interface") @@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource): @inner_api_ns.route("/invoke/tts") class PluginInvokeTTSApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTTS) @inner_api_ns.doc("plugin_invoke_tts") @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") @@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource): @inner_api_ns.route("/invoke/speech2text") class PluginInvokeSpeech2TextApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) @inner_api_ns.doc("plugin_invoke_speech2text") @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") @@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource): @inner_api_ns.route("/invoke/moderation") class PluginInvokeModerationApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeModeration) @inner_api_ns.doc("plugin_invoke_moderation") @inner_api_ns.doc(description="Invoke moderation models through plugin interface") @@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource): @inner_api_ns.route("/invoke/tool") class PluginInvokeToolApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTool) 
@inner_api_ns.doc("plugin_invoke_tool") @inner_api_ns.doc(description="Invoke tools through plugin interface") @@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource): @inner_api_ns.route("/invoke/parameter-extractor") class PluginInvokeParameterExtractorNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) @inner_api_ns.doc("plugin_invoke_parameter_extractor") @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") @@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource): @inner_api_ns.route("/invoke/question-classifier") class PluginInvokeQuestionClassifierNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) @inner_api_ns.doc("plugin_invoke_question_classifier") @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") @@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): @inner_api_ns.route("/invoke/app") class PluginInvokeAppApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeApp) @inner_api_ns.doc("plugin_invoke_app") @inner_api_ns.doc(description="Invoke application through plugin interface") @@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource): @inner_api_ns.route("/invoke/encrypt") class PluginInvokeEncryptApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeEncrypt) @inner_api_ns.doc("plugin_invoke_encrypt") @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") @@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource): @inner_api_ns.route("/invoke/summary") class PluginInvokeSummaryApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSummary) @inner_api_ns.doc("plugin_invoke_summary") @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") @@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource): @inner_api_ns.route("/upload/file/request") class PluginUploadFileRequestApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestRequestUploadFile) @inner_api_ns.doc("plugin_upload_file_request") @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") @@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource): @inner_api_ns.route("/fetch/app/info") class PluginFetchAppInfoApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestFetchAppInfo) @inner_api_ns.doc("plugin_fetch_app_info") @inner_api_ns.doc(description="Fetch application information through plugin interface") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index f751e06ddf..3776d0be0e 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional +from typing import ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in @@ -8,11 +8,13 @@ from flask_restx import reqparse from pydantic import BaseModel from 
sqlalchemy.orm import Session -from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db -from libs.login import _get_user +from libs.login import current_user from models.account import Tenant -from models.model import EndUser +from models.model import DefaultEndUserSessionID, EndUser + +P = ParamSpec("P") +R = TypeVar("R") def get_user(tenant_id: str, user_id: str | None) -> EndUser: @@ -25,7 +27,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: try: with Session(db.engine) as session: if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value user_model = ( session.query(EndUser) @@ -39,7 +41,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = EndUser( tenant_id=tenant_id, type="service_api", - is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, session_id=user_id, ) session.add(user_model) @@ -52,28 +54,25 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable] = None): - def decorator(view_func): +def get_user_tenant(view: Callable[P, R] | None = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): # fetch json body parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json") - kwargs = parser.parse_args() + p = parser.parse_args() - user_id = kwargs.get("user_id") - tenant_id = kwargs.get("tenant_id") + user_id = cast(str, p.get("user_id")) + tenant_id = cast(str, p.get("tenant_id")) if not tenant_id: raise ValueError("tenant_id is required") if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID - - del kwargs["tenant_id"] - del kwargs["user_id"] + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value try: tenant_model = ( @@ -95,7 +94,7 @@ def get_user_tenant(view: Optional[Callable] = None): kwargs["user_model"] = user current_app.login_manager._update_request_context_with_user(user) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore return view_func(*args, **kwargs) @@ -107,9 +106,9 @@ def get_user_tenant(view: Optional[Callable] = None): return decorator(view) -def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): - def decorator(view_func): - def decorated_view(*args, **kwargs): +def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): + def decorator(view_func: Callable[P, R]): + def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() except Exception: diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index de4f1da801..4bdcc6832a 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]): return decorated -def enterprise_inner_api_user_auth(view): +def enterprise_inner_api_user_auth(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: return view(*args, **kwargs) diff --git 
a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index c344ffad08..d6fb2981e4 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="MCP API", description="API for Model Context Protocol operations", - doc="/docs", # Enable Swagger UI at /mcp/docs ) mcp_ns = Namespace("mcp", description="MCP operations", path="/") @@ -18,3 +17,10 @@ mcp_ns = Namespace("mcp", description="MCP operations", path="/") from . import mcp api.add_namespace(mcp_ns) + +__all__ = [ + "api", + "bp", + "mcp", + "mcp_ns", +] diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 43b59d5334..a8629dca20 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from flask import Response from flask_restx import Resource, reqparse @@ -73,7 +73,7 @@ class MCPAppApi(Resource): ValidationError: Invalid request format or parameters """ args = mcp_request_parser.parse_args() - request_id: Optional[Union[int, str]] = args.get("id") + request_id: Union[int, str] | None = args.get("id") mcp_request = self._parse_mcp_request(args) with Session(db.engine, expire_on_commit=False) as session: @@ -107,7 +107,7 @@ class MCPAppApi(Resource): def _process_mcp_message( self, mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification, - request_id: Optional[Union[int, str]], + request_id: Union[int, str] | None, app: App, mcp_server: AppMCPServer, user_input_form: list[VariableEntity], @@ -130,7 +130,7 @@ class MCPAppApi(Resource): def _handle_request( self, mcp_request: mcp_types.ClientRequest, - request_id: Optional[Union[int, str]], + request_id: Union[int, str] | None, app: App, mcp_server: AppMCPServer, user_input_form: list[VariableEntity], @@ -150,7 +150,7 @@ class MCPAppApi(Resource): def _get_user_input_form(self, app: App) -> list[VariableEntity]: """Get and convert user input form""" # Get raw user input form based on app mode - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if not app.workflow: raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") raw_user_input_form = app.workflow.user_input_form(to_old_structure=True) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 763345d723..9032733e2c 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -10,14 +10,50 @@ api = ExternalApi( version="1.0", title="Service API", description="API for application services", - doc="/docs", # Enable Swagger UI at /v1/docs ) service_api_ns = Namespace("service_api", description="Service operations", path="/") from . 
import index -from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow -from .dataset import dataset, document, hit_testing, metadata, segment, upload_file +from .app import ( + annotation, + app, + audio, + completion, + conversation, + file, + file_preview, + message, + site, + workflow, +) +from .dataset import ( + dataset, + document, + hit_testing, + metadata, + segment, +) from .workspace import models +__all__ = [ + "annotation", + "app", + "audio", + "completion", + "conversation", + "dataset", + "document", + "file", + "file_preview", + "hit_testing", + "index", + "message", + "metadata", + "models", + "segment", + "site", + "workflow", +] + api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9038bda11a..ad1bdc7334 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource): def put(self, app_model: App, annotation_id): """Update an existing annotation.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() annotation_id = str(annotation_id) @@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource): """Delete an annotation.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() annotation_id = str(annotation_id) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 2dbeed1d68..25d7ccccec 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -29,7 +29,7 @@ class AppParameterApi(Resource): Returns the input form parameters and configuration for the application. 
""" - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4860bf3a79..711dd5704c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,4 +1,5 @@ from flask_restx import Resource, reqparse +from flask_restx._http import HTTPStatus from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -121,7 +122,7 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) def delete(self, app_model: App, end_user: EndUser, c_id): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index f676374e5f..6a70345f7c 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -340,6 +340,9 @@ class DatasetApi(DatasetApiResource): else: data["embedding_available"] = True + # force update search method to keyword_search if indexing_technique is economic + data["retrieval_model_dict"]["search_method"] = "keyword_search" + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) @@ -559,7 +562,7 @@ class DatasetTagsApi(DatasetApiResource): def post(self, _, dataset_id): """Add a knowledge type tag.""" assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_create_parser.parse_args() @@ -583,7 +586,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def patch(self, _, dataset_id): assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_update_parser.parse_args() @@ -610,7 +613,7 @@ class DatasetTagsApi(DatasetApiResource): def delete(self, _, dataset_id): """Delete a knowledge type tag.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() args = tag_delete_parser.parse_args() TagService.delete_tag(args.get("tag_id")) @@ -634,7 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource): def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_binding_parser.parse_args() @@ -660,7 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): def post(self, _, dataset_id): # The role of the current user in the ta 
table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_unbinding_parser.parse_args() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 4bce64e0a1..6b635bcfbd 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -30,6 +30,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.model import EndUser from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -303,6 +304,8 @@ class DocumentAddByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), @@ -391,6 +394,8 @@ class DocumentUpdateByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") try: upload_file = FileService(db.engine).upload_file( filename=file.filename, diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py deleted file mode 100644 index 27b36a6402..0000000000 --- a/api/controllers/service_api/dataset/upload_file.py +++ /dev/null @@ -1,65 +0,0 @@ -from werkzeug.exceptions import NotFound - -from controllers.service_api import service_api_ns -from controllers.service_api.wraps import ( - DatasetApiResource, -) -from core.file import helpers as file_helpers -from extensions.ext_database import db -from models.dataset import Dataset -from models.model import UploadFile -from services.dataset_service import DocumentService - - -@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file") -class UploadFileApi(DatasetApiResource): - @service_api_ns.doc("get_upload_file") - @service_api_ns.doc(description="Get upload file information and download URL") - @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @service_api_ns.doc( - responses={ - 200: "Upload file information retrieved successfully", - 401: "Unauthorized - invalid API token", - 404: "Dataset, document, or upload file not found", - } - ) - def get(self, tenant_id, dataset_id, document_id): - """Get upload file information and download URL. - - Returns information about an uploaded file including its download URL.
- """ - # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: - raise NotFound("Dataset not found.") - # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) - if not document: - raise NotFound("Document not found.") - # check upload file - if document.data_source_type != "upload_file": - raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") - data_source_info = document.data_source_info_dict - if data_source_info and "upload_file_id" in data_source_info: - file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() - if not upload_file: - raise NotFound("UploadFile not found.") - else: - raise ValueError("Upload file id not found in document data source info.") - - url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": url, - "download_url": f"{url}&as_attachment=true", - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at.timestamp(), - }, 200 diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 536cf81a2f..fffcb47bd4 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource): } ) @validate_dataset_token - def get(self, _, model_type): + def get(self, _, model_type: str): """Get available models by model type. Returns a list of available models for the specified model type. 
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index e8816c74a9..246d3750d1 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -13,18 +13,18 @@ from sqlalchemy import select, update from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import _get_user +from libs.login import current_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog -from models.model import ApiToken, App, EndUser +from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser from services.feature_service import FeatureService P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") class WhereisUserArg(StrEnum): @@ -42,10 +42,10 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): - def decorator(view_func): +def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("app") app_model = db.session.query(App).where(App.id == api_token.app_id).first() @@ -189,17 +189,17 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view=None): - def decorator(view): +def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): + def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): # get url path dataset_id from positional args or kwargs # Flask passes URL path parameters as positional arguments dataset_id = None - + # First try to get from kwargs (explicit parameter) dataset_id = kwargs.get("dataset_id") - + # If not in kwargs, try to extract from positional args if not dataset_id and args: # For class methods: args[0] is self, args[1] is dataset_id (if exists) @@ -225,7 +225,7 @@ def validate_dataset_token(view=None): dataset_id = str_id except: pass - + # Validate dataset if dataset_id is provided if dataset_id: dataset_id = str(dataset_id) @@ -250,7 +250,7 @@ def validate_dataset_token(view=None): if account: account.current_tenant = tenant current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: @@ -308,12 +308,12 @@ def validate_and_get_api_token(scope: str | None = None): return api_token -def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = 
None) -> EndUser: +def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = None) -> EndUser: """ Create or update session terminal based on user ID. """ if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value with Session(db.engine, expire_on_commit=False) as session: end_user = ( @@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] tenant_id=app_model.tenant_id, app_id=app_model.id, type="service_api", - is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, session_id=user_id, ) session.add(end_user) diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 3b0a9e341a..1d22954308 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Web API", description="Public APIs for web applications including file uploads, chat interactions, and app management", - doc="/docs", # Enable Swagger UI at /api/docs ) # Create namespace @@ -34,3 +33,23 @@ from . import ( ) api.add_namespace(web_ns) + +__all__ = [ + "api", + "app", + "audio", + "bp", + "completion", + "conversation", + "feature", + "files", + "forgot_password", + "login", + "message", + "passport", + "remote_files", + "saved_message", + "site", + "web_ns", + "workflow", +] diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index e0c3e997ce..2bc068ec75 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource): @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2c0f6c9759..c1c46891b6 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -5,7 +5,7 @@ from flask_restx import fields, marshal_with, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, AudioTooLargeError, @@ -32,15 +32,16 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@web_ns.route("/audio-to-text") class AudioApi(WebApiResource): audio_to_text_response_fields = { "text": fields.String, } @marshal_with(audio_to_text_response_fields) - @api.doc("Audio to Text") - @api.doc(description="Convert audio file to text using speech-to-text service.") - @api.doc( + @web_ns.doc("Audio to Text") + @web_ns.doc(description="Convert audio file to text using speech-to-text service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -85,6 +86,7 @@ class AudioApi(WebApiResource): raise InternalServerError() +@web_ns.route("/text-to-audio") class TextApi(WebApiResource): text_to_audio_response_fields = { "audio_url": fields.String, @@ -92,9 +94,9 @@ class TextApi(WebApiResource): } @marshal_with(text_to_audio_response_fields) - @api.doc("Text to Audio") - @api.doc(description="Convert text to audio using text-to-speech service.") - @api.doc( + @web_ns.doc("Text to Audio") + @web_ns.doc(description="Convert text to audio 
using text-to-speech service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -145,7 +147,3 @@ class TextApi(WebApiResource): except Exception as e: logger.exception("Failed to handle post request to TextApi") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index a42bf5fc6e..67ae970388 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -4,7 +4,7 @@ from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, CompletionRequestError, @@ -35,10 +35,11 @@ logger = logging.getLogger(__name__) # define completion api for user +@web_ns.route("/completion-messages") class CompletionApi(WebApiResource): - @api.doc("Create Completion Message") - @api.doc(description="Create a completion message for text generation applications.") - @api.doc( + @web_ns.doc("Create Completion Message") + @web_ns.doc(description="Create a completion message for text generation applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, "query": {"description": "Query text for completion", "type": "string", "required": False}, @@ -52,7 +53,7 @@ class CompletionApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -106,11 +107,12 @@ class CompletionApi(WebApiResource): raise InternalServerError() +@web_ns.route("/completion-messages//stop") class CompletionStopApi(WebApiResource): - @api.doc("Stop Completion Message") - @api.doc(description="Stop a running completion message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Completion Message") + @web_ns.doc(description="Stop a running completion message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -129,10 +131,11 @@ class CompletionStopApi(WebApiResource): return {"result": "success"}, 200 +@web_ns.route("/chat-messages") class ChatApi(WebApiResource): - @api.doc("Create Chat Message") - @api.doc(description="Create a chat message for conversational applications.") - @api.doc( + @web_ns.doc("Create Chat Message") + @web_ns.doc(description="Create a chat message for conversational applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, "query": {"description": "User query/message", "type": "string", "required": True}, @@ -148,7 +151,7 @@ class ChatApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -207,11 +210,12 @@ class ChatApi(WebApiResource): raise InternalServerError() +@web_ns.route("/chat-messages//stop") class ChatStopApi(WebApiResource): - @api.doc("Stop Chat Message") - @api.doc(description="Stop a running chat message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", 
"required": True}}) - @api.doc( + @web_ns.doc("Stop Chat Message") + @web_ns.doc(description="Stop a running chat message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -229,9 +233,3 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 24de4f3f2e..03dd986aed 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -3,7 +3,7 @@ from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +16,44 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +@web_ns.route("/conversations") class ConversationListApi(WebApiResource): + @web_ns.doc("Get Conversation List") + @web_ns.doc(description="Retrieve paginated list of conversations for a chat application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last conversation ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of conversations to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + "pinned": { + "description": "Filter by pinned status", + "type": "string", + "enum": ["true", "false"], + "required": False, + }, + "sort_by": { + "description": "Sort order", + "type": "string", + "enum": ["created_at", "-created_at", "updated_at", "-updated_at"], + "required": False, + "default": "-updated_at", + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -57,11 +94,25 @@ class ConversationListApi(WebApiResource): raise NotFound("Last Conversation Not Exists.") +@web_ns.route("/conversations/") class ConversationApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Conversation") + @web_ns.doc(description="Delete a specific conversation.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -76,7 +127,32 @@ class ConversationApi(WebApiResource): return {"result": "success"}, 204 +@web_ns.route("/conversations//name") class ConversationRenameApi(WebApiResource): + 
@web_ns.doc("Rename Conversation") + @web_ns.doc(description="Rename a specific conversation with a custom name or auto-generate one.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "name": {"description": "New conversation name", "type": "string", "required": False}, + "auto_generate": { + "description": "Auto-generate conversation name", + "type": "boolean", + "required": False, + "default": False, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -96,11 +172,25 @@ class ConversationRenameApi(WebApiResource): raise NotFound("Conversation Not Exists.") +@web_ns.route("/conversations//pin") class ConversationPinApi(WebApiResource): pin_response_fields = { "result": fields.String, } + @web_ns.doc("Pin Conversation") + @web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation pinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(pin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -117,11 +207,25 @@ class ConversationPinApi(WebApiResource): return {"result": "success"} +@web_ns.route("/conversations//unpin") class ConversationUnPinApi(WebApiResource): unpin_response_fields = { "result": fields.String, } + @web_ns.doc("Unpin Conversation") + @web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation unpinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(unpin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -132,10 +236,3 @@ class ConversationUnPinApi(WebApiResource): WebConversationService.unpin(app_model, conversation_id, end_user) return {"result": "success"} - - -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") -api.add_resource(ConversationListApi, "/conversations") -api.add_resource(ConversationApi, "/conversations/") -api.add_resource(ConversationPinApi, "/conversations//pin") -api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 17e06e8856..26c0b133d9 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, @@ 
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 17e06e8856..26c0b133d9 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, @@ -38,6 +38,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { "id": fields.String, @@ -62,6 +63,30 @@ class MessageListApi(WebApiResource): "data": fields.List(fields.Nested(message_fields)), } + @web_ns.doc("Get Message List") + @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") + @web_ns.doc( + params={ + "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, + "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -84,11 +109,36 @@ class MessageListApi(WebApiResource): raise NotFound("First Message Not Exists.") +@web_ns.route("/messages/<uuid:message_id>/feedbacks") class MessageFeedbackApi(WebApiResource): feedback_response_fields = { "result": fields.String, } + @web_ns.doc("Create Message Feedback") + @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "rating": { + "description": "Feedback rating", + "type": "string", + "enum": ["like", "dislike"], + "required": False, + }, + "content": {"description": "Feedback content/comment", "type": "string", "required": False}, + } + ) + @web_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(feedback_response_fields) def post(self, app_model, end_user, message_id): message_id = str(message_id) @@ -112,7 +162,31 @@ class MessageFeedbackApi(WebApiResource): return {"result": "success"} +@web_ns.route("/messages/<uuid:message_id>/more-like-this") class MessageMoreLikeThisApi(WebApiResource): + @web_ns.doc("Generate More Like This") + @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID", "type": "string", "required": True}, + "response_mode": { + "description": "Response mode", + "type": "string", + "enum": ["blocking", "streaming"], + "required": True, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) def get(self, app_model, end_user, message_id): if app_model.mode != "completion": raise NotCompletionAppError() @@ -156,11 +230,25 @@ class MessageMoreLikeThisApi(WebApiResource): raise InternalServerError() +@web_ns.route("/messages/<uuid:message_id>/suggested-questions") class MessageSuggestedQuestionApi(WebApiResource): suggested_questions_response_fields = { "data": fields.List(fields.String), } + @web_ns.doc("Get Suggested Questions") + @web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a chat app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found or Conversation Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) @@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks") -api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this") -api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions")
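The list endpoints here validate pagination input with reqparse and flask_restx.inputs.int_range, which raises a 400 during parse_args() when a value falls outside the allowed window. A small sketch; the argument names mirror the handlers above, the parser itself is illustrative:

    from flask_restx import reqparse
    from flask_restx.inputs import int_range

    parser = reqparse.RequestParser()
    parser.add_argument("first_id", type=str, location="args")
    parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")

    # inside a request handler:
    # args = parser.parse_args()  # {"first_id": ..., "limit": 20}, or an automatic 400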
+ @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a chat app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found or Conversation Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) @@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") -api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 7a9d24114e..96f09c8d3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields @@ -23,6 +23,7 @@ message_fields = { } +@web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -34,6 +35,29 @@ class SavedMessageListApi(WebApiResource): "result": fields.String, } + @web_ns.doc("Get Saved Messages") + @web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): if app_model.mode != "completion": @@ -46,6 +70,23 @@ class SavedMessageListApi(WebApiResource): return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + @web_ns.doc("Save Message") + @web_ns.doc(description="Save a specific message for later reference.") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID to save", "type": "string", "required": True}, + } + ) + @web_ns.doc( + responses={ + 200: "Message saved successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(post_response_fields) def post(self, app_model, end_user): if app_model.mode != "completion": @@ -63,11 +104,25 @@ class SavedMessageListApi(WebApiResource): return {"result": "success"} +@web_ns.route("/saved-messages/") class SavedMessageApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Saved Message") + @web_ns.doc(description="Remove a message from saved messages.") + 
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Message removed successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, message_id): message_id = str(message_id) @@ -78,7 +133,3 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) return {"result": "success"}, 204 - - -api.add_resource(SavedMessageListApi, "/saved-messages") -api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 91d67bf9d8..b01aaba357 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.web import api +from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField @@ -11,6 +11,7 @@ from models.model import Site from services.feature_service import FeatureService +@web_ns.route("/site") class AppSiteApi(WebApiResource): """Resource for app sites.""" @@ -53,9 +54,9 @@ class AppSiteApi(WebApiResource): "custom_config": fields.Raw(attribute="custom_config"), } - @api.doc("Get App Site Info") - @api.doc(description="Retrieve app site information and configuration.") - @api.doc( + @web_ns.doc("Get App Site Info") + @web_ns.doc(description="Retrieve app site information and configuration.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -82,9 +83,6 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, "/site") - - class AppSiteInfo: """Class to store site information.""" diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 58a70d5961..9a980148d9 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -3,7 +3,7 @@ import logging from flask_restx import reqparse from werkzeug.exceptions import InternalServerError -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( CompletionRequestError, NotWorkflowAppError, @@ -30,16 +30,17 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +@web_ns.route("/workflows/run") class WorkflowRunApi(WebApiResource): - @api.doc("Run Workflow") - @api.doc(description="Execute a workflow with provided inputs and files.") - @api.doc( + @web_ns.doc("Run Workflow") + @web_ns.doc(description="Execute a workflow with provided inputs and files.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True}, "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -85,15 +86,16 @@ class WorkflowRunApi(WebApiResource): raise InternalServerError() +@web_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(WebApiResource): - @api.doc("Stop Workflow Task") - @api.doc(description="Stop a running workflow task.") - @api.doc( + @web_ns.doc("Stop Workflow Task") + 
@web_ns.doc(description="Stop a running workflow task.") + @web_ns.doc( params={ "task_id": {"description": "Task ID to stop", "type": "string", "required": True}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -119,7 +121,3 @@ class WorkflowTaskStopApi(WebApiResource): GraphEngineManager.send_stop_command(task_id) return {"result": "success"} - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1fbb2c165f..ba03c4eae4 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -20,12 +21,11 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view=None): - def decorator(view): +def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None): + def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): app_model, end_user = decode_jwt_token() - return view(app_model, end_user, *args, **kwargs) return decorated diff --git a/api/core/__init__.py b/api/core/__init__.py index 6eaea7b1c8..e69de29bb2 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +0,0 @@ -import core.moderation.base diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index a2ee2e57e8..c196dbbdf1 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,7 @@ import json import logging import uuid -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select @@ -60,8 +60,8 @@ class BaseAgentRunner(AppRunner): message: Message, user_id: str, model_instance: ModelInstance, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, + memory: TokenBufferMemory | None = None, + prompt_messages: list[PromptMessage] | None = None, ): self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -114,7 +114,7 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query: Optional[str] = "" + self.query: str | None = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index b94a60c40a..25ad6dc060 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -70,10 +70,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): self._prompt_messages_tools = prompt_messages_tools function_call_state = True - llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} + llm_usage: dict[str, LLMUsage 
| None] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages + agent_thought_id = "" # Initialize agent_thought_id - def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -120,7 +122,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): callbacks=[], ) - usage_dict: dict[str, Optional[LLMUsage]] = {} + usage_dict: dict[str, LLMUsage | None] = {} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -272,7 +274,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): action: AgentScratchpadUnit.Action, tool_instances: Mapping[str, Tool], message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None, + trace_manager: TraceQueueManager | None = None, ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3a4d31e047..da9a001d84 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,5 +1,4 @@ import json -from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner from core.model_runtime.entities.message_entities import ( @@ -31,7 +30,7 @@ class CotCompletionAgentRunner(CotAgentRunner): return system_prompt - def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str: + def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str: """ Organize historic prompt """ diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 816d2782f0..220feced1d 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field @@ -50,11 +50,11 @@ class AgentScratchpadUnit(BaseModel): "action_input": self.action_input, } - agent_response: Optional[str] = None - thought: Optional[str] = None - action_str: Optional[str] = None - observation: Optional[str] = None - action: Optional[Action] = None + agent_response: str | None = None + thought: str | None = None + action_str: str | None = None + observation: str | None = None + action: Action | None = None def is_final(self) -> bool: """ @@ -81,8 +81,8 @@ class AgentEntity(BaseModel): provider: str model: str strategy: Strategy - prompt: Optional[AgentPromptEntity] = None - tools: Optional[list[AgentToolEntity]] = None + prompt: AgentPromptEntity | None = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 10 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9eb853aa74..dcc1326b33 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Generator from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom @@ -52,13 +52,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): # continue to run until there is not any tool call function_call_state = True - llm_usage: 
dict[str, Optional[LLMUsage]] = {"usage": None} + llm_usage: dict[str, LLMUsage | None] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages # get tracing instance trace_manager = app_generate_entity.trace_manager - def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index a3438fc2c7..90aa7b5fd4 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,5 +1,5 @@ -import enum -from typing import Any, Optional +from enum import StrEnum +from typing import Any from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity): class AgentStrategyParameter(PluginParameter): - class AgentStrategyParameterType(enum.StrEnum): + class AgentStrategyParameterType(StrEnum): """ Keep all the types from PluginParameterType """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -53,7 +53,7 @@ class AgentStrategyParameter(PluginParameter): return cast_parameter_value(self, value) type: AgentStrategyParameterType = Field(..., description="The type of the parameter") - help: Optional[I18nObject] = None + help: I18nObject | None = None def init_frontend_parameter(self, value: Any): return init_frontend_parameter(self, self.type, value) @@ -61,7 +61,7 @@ class AgentStrategyParameter(PluginParameter): class AgentStrategyProviderEntity(BaseModel): identity: AgentStrategyProviderIdentity - plugin_id: Optional[str] = Field(None, description="The id of the plugin") + plugin_id: str | None = Field(None, description="The id of the plugin") class AgentStrategyIdentity(ToolIdentity): @@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity): pass -class AgentFeature(enum.StrEnum): +class AgentFeature(StrEnum): """ Agent Feature, used to describe the features of the agent strategy. 
""" @@ -84,9 +84,9 @@ class AgentStrategyEntity(BaseModel): identity: AgentStrategyIdentity parameters: list[AgentStrategyParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The description of the agent strategy") - output_schema: Optional[dict] = None - features: Optional[list[AgentFeature]] = None - meta_version: Optional[str] = None + output_schema: dict | None = None + features: list[AgentFeature] | None = None + meta_version: str | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index a52a1dfd7a..8a9be05dde 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Sequence -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyParameter @@ -16,10 +16,10 @@ class BaseAgentStrategy(ABC): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. @@ -37,9 +37,9 @@ class BaseAgentStrategy(ABC): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 04661581a7..a3cc798352 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter @@ -38,10 +38,10 @@ class PluginAgentStrategy(BaseAgentStrategy): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. 
diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 037037e6ca..e925d6dd52 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,12 +1,10 @@ -from typing import Optional - from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: + def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None @@ -21,7 +19,7 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def validate_and_set_defaults( - cls, tenant_id, config: dict, only_structure_validate: bool = False + cls, tenant_id: str, config: dict, only_structure_validate: bool = False ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = {"enabled": False} @@ -38,7 +36,14 @@ class SensitiveWordAvoidanceConfigManager: if not only_structure_validate: typ = config["sensitive_word_avoidance"]["type"] - sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] + if not isinstance(typ, str): + raise ValueError("sensitive_word_avoidance.type must be a string") + + sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config") + if sensitive_word_avoidance_config is None: + sensitive_word_avoidance_config = {} + if not isinstance(sensitive_word_avoidance_config, dict): + raise ValueError("sensitive_word_avoidance.config must be a dict") ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 8887d2500c..eab26e5af9 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,12 +1,10 @@ -from typing import Optional - from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[AgentEntity]: + def convert(cls, config: dict) -> AgentEntity | None: """ Convert model config to model config diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index fcbf479e2e..4b824bde76 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional from core.app.app_config.entities import ( DatasetEntity, @@ -14,7 +13,7 @@ from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[DatasetEntity]: + def convert(cls, config: dict) -> DatasetEntity | None: """ Convert model config to model config diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index e6ab31e586..ec4f6074ab 100644 --- 
a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -25,10 +25,14 @@ class PromptTemplateConfigManager: if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): + text = message.get("text") + if not isinstance(text, str): + raise ValueError("message text must be a string") + role = message.get("role") + if not isinstance(role, str): + raise ValueError("message role must be a string") chat_prompt_messages.append( - AdvancedChatMessageEntity( - **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} - ) + AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role)) ) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) @@ -66,7 +70,7 @@ class PromptTemplateConfigManager: :param config: app model config args """ if not config.get("prompt_type"): - config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] if config["prompt_type"] not in prompt_type_vals: @@ -86,7 +90,7 @@ class PromptTemplateConfigManager: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED: if not config["chat_prompt_config"] and not config["completion_prompt_config"]: raise ValueError( "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 895ee8581e..2ad81fe005 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,6 +1,6 @@ from collections.abc import Sequence -from enum import Enum, StrEnum -from typing import Any, Literal, Optional +from enum import StrEnum, auto +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel): provider: str model: str - mode: Optional[str] = None + mode: str | None = None parameters: dict[str, Any] = Field(default_factory=dict) stop: list[str] = Field(default_factory=list) @@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): assistant: str prompt: str - role_prefix: Optional[RolePrefixEntity] = None + role_prefix: RolePrefixEntity | None = None class PromptTemplateEntity(BaseModel): @@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel): Prompt Template Entity. """ - class PromptType(Enum): + class PromptType(StrEnum): """ Prompt Type. 
        'simple', 'advanced'
        """

-        SIMPLE = "simple"
-        ADVANCED = "advanced"
+        SIMPLE = auto()
+        ADVANCED = auto()

        @classmethod
        def value_of(cls, value: str):
@@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel):
        raise ValueError(f"invalid prompt type value {value}")

    prompt_type: PromptType
-    simple_prompt_template: Optional[str] = None
-    advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
-    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
+    simple_prompt_template: str | None = None
+    advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None
+    advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None


 class VariableEntityType(StrEnum):
@@ -112,7 +112,7 @@ class VariableEntity(BaseModel):
    type: VariableEntityType
    required: bool = False
    hide: bool = False
-    max_length: Optional[int] = None
+    max_length: int | None = None
    options: Sequence[str] = Field(default_factory=list)
-    allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
-    allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
@@ -183,7 +183,7 @@ class ModelConfig(BaseModel):

 class Condition(BaseModel):
    """
-    Conditon detail
+    Condition detail
    """

    name: str
@@ -196,8 +196,8 @@ class MetadataFilteringCondition(BaseModel):
    Metadata Filtering Condition.
    """

-    logical_operator: Optional[Literal["and", "or"]] = "and"
-    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+    logical_operator: Literal["and", "or"] | None = "and"
+    conditions: list[Condition] | None = Field(default=None, deprecated=True)


 class DatasetRetrieveConfigEntity(BaseModel):
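An aside on the `Enum` to `StrEnum` conversions in this file: they are value-preserving only because `auto()` on a `StrEnum` yields the lowercased member name, which happens to match the old string values here. A standalone sketch of that behavior, not part of the patch:

```python
# Standalone sketch: why Enum -> StrEnum with auto() keeps the old values here.
from enum import StrEnum, auto


class PromptType(StrEnum):
    # auto() on a StrEnum produces the lowercased member name, so these
    # keep the exact "simple"/"advanced" values the old Enum spelled out.
    SIMPLE = auto()
    ADVANCED = auto()


assert PromptType.SIMPLE == "simple"             # members compare equal to plain str
assert PromptType("advanced") is PromptType.ADVANCED
assert str(PromptType.SIMPLE) == "simple"        # and render as their value
```

Members whose values are not the lowercased member name, such as the hyphenated "app-latest-config" kept explicit a few hunks below, must stay spelled out: `auto()` would produce "app_latest_config" instead.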
@@ -205,14 +205,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
    """
    Dataset Retrieve Config Entity.
    """

-    class RetrieveStrategy(Enum):
+    class RetrieveStrategy(StrEnum):
        """
        Dataset Retrieve Strategy.
        'single' or 'multiple'
        """

-        SINGLE = "single"
-        MULTIPLE = "multiple"
+        SINGLE = auto()
+        MULTIPLE = auto()

        @classmethod
        def value_of(cls, value: str):
@@ -227,18 +227,18 @@ class DatasetRetrieveConfigEntity(BaseModel):
                return mode
        raise ValueError(f"invalid retrieve strategy value {value}")

-    query_variable: Optional[str] = None  # Only when app mode is completion
+    query_variable: str | None = None  # Only when app mode is completion
    retrieve_strategy: RetrieveStrategy
-    top_k: Optional[int] = None
-    score_threshold: Optional[float] = 0.0
-    rerank_mode: Optional[str] = "reranking_model"
-    reranking_model: Optional[dict] = None
-    weights: Optional[dict] = None
-    reranking_enabled: Optional[bool] = True
-    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
-    metadata_model_config: Optional[ModelConfig] = None
-    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
+    top_k: int | None = None
+    score_threshold: float | None = 0.0
+    rerank_mode: str | None = "reranking_model"
+    reranking_model: dict | None = None
+    weights: dict | None = None
+    reranking_enabled: bool | None = True
+    metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
+    metadata_model_config: ModelConfig | None = None
+    metadata_filtering_conditions: MetadataFilteringCondition | None = None


 class DatasetEntity(BaseModel):
@@ -265,8 +265,8 @@ class TextToSpeechEntity(BaseModel):
    """

    enabled: bool
-    voice: Optional[str] = None
-    language: Optional[str] = None
+    voice: str | None = None
+    language: str | None = None


 class TracingConfigEntity(BaseModel):
@@ -279,15 +279,15 @@


 class AppAdditionalFeatures(BaseModel):
-    file_upload: Optional[FileUploadConfig] = None
-    opening_statement: Optional[str] = None
+    file_upload: FileUploadConfig | None = None
+    opening_statement: str | None = None
    suggested_questions: list[str] = []
    suggested_questions_after_answer: bool = False
    show_retrieve_source: bool = False
    more_like_this: bool = False
    speech_to_text: bool = False
-    text_to_speech: Optional[TextToSpeechEntity] = None
-    trace_config: Optional[TracingConfigEntity] = None
+    text_to_speech: TextToSpeechEntity | None = None
+    trace_config: TracingConfigEntity | None = None


 class AppConfig(BaseModel):
@@ -300,15 +300,15 @@
    app_mode: AppMode
-    additional_features: Optional[AppAdditionalFeatures] = None
+    additional_features: AppAdditionalFeatures | None = None
    variables: list[VariableEntity] = []
-    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
+    sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None


-class EasyUIBasedAppModelConfigFrom(Enum):
+class EasyUIBasedAppModelConfigFrom(StrEnum):
    """
    App Model Config From.
""" - ARGS = "args" + ARGS = auto() APP_LATEST_CONFIG = "app-latest-config" CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" @@ -323,7 +323,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_dict: dict model: ModelConfigEntity prompt_template: PromptTemplateEntity - dataset: Optional[DatasetEntity] = None + dataset: DatasetEntity | None = None external_data_variables: list[ExternalDataVariableEntity] = [] diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index d50be956d4..35fdb865ed 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Literal, Optional, Union, overload +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity: AdvancedChatAppGenerateEntity, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, - conversation: Optional[Conversation] = None, + conversation: Conversation | None = None, stream: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 61fed3bae9..af8b7e4e17 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -249,7 +249,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def query_app_annotations_to_reply( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 627f6b47ce..02ec96f209 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -110,7 +110,7 
@@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 726cf7e4d7..71588870fa 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager from threading import Thread -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -169,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline: generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -228,7 +228,7 @@ class AdvancedChatAppGenerateTaskPipeline: return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id @@ -289,7 +289,7 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: """Fluent validation for graph runtime state.""" if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -297,13 +297,13 @@ class AdvancedChatAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" with self._database_session() as session: - err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) - yield self._base_task_pipeline._error_to_stream_response(err) + err = 
self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" @@ -404,8 +404,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueTextChunkEvent, *, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle text chunk events.""" @@ -505,8 +505,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" @@ -536,8 +536,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" @@ -568,8 +568,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowFailedEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed events.""" @@ -594,17 +594,17 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_execution=workflow_execution, ) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) - err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) + err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id) yield workflow_finish_resp - yield self._base_task_pipeline._error_to_stream_response(err) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_stop_event( self, event: QueueStopEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle stop events.""" @@ -644,13 +644,13 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueAdvancedChatMessageEndEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, + graph_runtime_state: GraphRuntimeState | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" self._ensure_graph_runtime_initialized(graph_runtime_state) - output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished( 
self._task_state.answer ) if output_moderation_answer: @@ -740,10 +740,10 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: Any, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + graph_runtime_state: GraphRuntimeState | None = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, ) -> Generator[StreamResponse, None, None]: """Dispatch events using elegant pattern matching.""" handlers = self._get_event_handlers() @@ -782,15 +782,15 @@ class AdvancedChatAppGenerateTaskPipeline: def _process_stream_response( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response using elegant Fluent Python patterns. Maintains exact same functionality as original 57-if-statement version. """ # Initialize graph runtime state - graph_runtime_state: Optional[GraphRuntimeState] = None + graph_runtime_state: GraphRuntimeState | None = None for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event @@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline: if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): + def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer @@ -846,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer = answer_text message.updated_at = naive_utc_now() - message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( @@ -902,9 +902,9 @@ class AdvancedChatAppGenerateTaskPipeline: :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._base_task_pipeline._output_moderation_handler: - if self._base_task_pipeline._output_moderation_handler.should_direct_output(): - self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + if self._base_task_pipeline.output_moderation_handler: + if self._base_task_pipeline.output_moderation_handler.should_direct_output(): + self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output() self._base_task_pipeline.queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) @@ -914,7 +914,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) return True else: - self._base_task_pipeline._output_moderation_handler.append_new_token(text) + self._base_task_pipeline.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py 
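An aside on the task-pipeline hunks above: cross-object reads of `_`-prefixed members (`_stream`, `_start_at`, `_handle_error`, `_output_moderation_handler`) are replaced with a public surface. A minimal sketch of the pattern, using simplified, hypothetical class names:

```python
# Illustrative sketch only (simplified, hypothetical classes): collaborators
# consume a public surface instead of reaching into _-prefixed internals.
import time
from dataclasses import dataclass, field


@dataclass
class BaseTaskPipeline:
    stream: bool = True                                    # public: read by collaborators
    start_at: float = field(default_factory=time.perf_counter)

    def handle_error(self, event: Exception) -> str:
        """Public helper; previously spelled _handle_error."""
        return f"error: {event}"


class GenerateTaskPipeline:
    def __init__(self, base: BaseTaskPipeline) -> None:
        self._base = base

    def process(self) -> str:
        # Before the refactor this read self._base._stream and called
        # self._base._handle_error(...), coupling to another object's internals.
        if self._base.stream:
            return "streaming"
        return self._base.handle_error(ValueError("run failed"))


print(GenerateTaskPipeline(BaseTaskPipeline(stream=False)).process())
```

The `invoke_from` accessor added to `AppQueueManager` below follows the same transition: the public name is introduced first so callers such as the workflow pipeline can stop touching `_invoke_from`.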
index 349b583833..9ce841f432 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, cast from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): Agent Chatbot App Config Entity. """ - agent: Optional[AgentEntity] = None + agent: AgentEntity | None = None class AgentChatAppConfigManager(BaseAppConfigManager): @@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager): cls, app_model: App, app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None, + conversation: Conversation | None = None, + override_config_dict: dict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config @@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return filtered_config @classmethod - def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_agent_mode_and_set_defaults( + cls, tenant_id: str, config: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: """ Validate agent_mode and set defaults for agent feature @@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config.get("agent_mode"): config["agent_mode"] = {"enabled": False, "tools": []} - if not isinstance(config["agent_mode"], dict): + agent_mode = config["agent_mode"] + if not isinstance(agent_mode, dict): raise ValueError("agent_mode must be of object type") - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False + # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing + agent_mode = cast(dict[str, Any], agent_mode) - if not isinstance(config["agent_mode"]["enabled"], bool): + if "enabled" not in agent_mode or not agent_mode["enabled"]: + agent_mode["enabled"] = False + + if not isinstance(agent_mode["enabled"], bool): raise ValueError("enabled in agent_mode must be of boolean type") - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + if not agent_mode.get("strategy"): + agent_mode["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [ - member.value for member in list(PlanningStrategy.__members__.values()) - ]: + if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: raise ValueError("strategy in agent_mode must be in the specified strategy list") - if not config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] + if not agent_mode.get("tools"): + agent_mode["tools"] = [] - if not isinstance(config["agent_mode"]["tools"], list): + if not isinstance(agent_mode["tools"], list): raise ValueError("tools in agent_mode must be a list of objects") - for tool in config["agent_mode"]["tools"]: + for tool in agent_mode["tools"]: key = list(tool.keys())[0] if key in OLD_TOOLS: # old style, use tool name as key diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 89a5b8e3b5..e35e9d9408 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ 
b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 19cfde3d5c..90a8040d81 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, Union, final +from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session @@ -25,7 +25,7 @@ class BaseAppGenerator: def _prepare_user_inputs( self, *, - user_inputs: Optional[Mapping[str, Any]], + user_inputs: Mapping[str, Any] | None, variables: Sequence["VariableEntity"], tenant_id: str, strict_type_validation: bool = False, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index cdaf68ce65..fdba952eeb 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -2,7 +2,7 @@ import queue import time from abc import abstractmethod from enum import IntEnum, auto -from typing import Any, Optional +from typing import Any from sqlalchemy.orm import DeclarativeMeta @@ -32,6 +32,7 @@ class AppQueueManager: self._task_id = task_id self._user_id = user_id self._invoke_from = invoke_from + self.invoke_from = invoke_from # Public accessor for invoke_from user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( @@ -115,7 +116,7 @@ class AppQueueManager: Set task stop flag :return: """ - result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id)) + result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id)) if result is None: return diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index dafdcdd429..e7db3bc41b 100644 --- a/api/core/app/apps/base_app_runner.py +++ 
b/api/core/app/apps/base_app_runner.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -82,11 +82,11 @@ class AppRunner: prompt_template_entity: PromptTemplateEntity, inputs: Mapping[str, str], files: Sequence["File"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + query: str | None = None, + context: str | None = None, + memory: TokenBufferMemory | None = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages :param context: @@ -161,7 +161,7 @@ class AppRunner: prompt_messages: list, text: str, stream: bool, - usage: Optional[LLMUsage] = None, + usage: LLMUsage | None = None, ): """ Direct output @@ -375,7 +375,7 @@ class AppRunner: def query_app_annotations_to_reply( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 96a3db8502..4b6720a3c3 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager): cls, app_model: App, app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None, + conversation: Conversation | None = None, + override_config_dict: dict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 816d6d79a9..3aa1161fd8 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class 
ChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 1ece0d3d63..b1c7ae0df5 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,7 +1,7 @@ import time from collections.abc import Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy.orm import Session @@ -135,7 +135,7 @@ class WorkflowResponseConverter: event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeStartStreamResponse]: + ) -> NodeStartStreamResponse | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: @@ -190,7 +190,7 @@ class WorkflowResponseConverter: event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: + ) -> NodeFinishStreamResponse | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: @@ -235,7 +235,7 @@ class WorkflowResponseConverter: event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + ) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 3a1f29689d..eb1902f12e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None ) -> CompletionAppConfig: """ Convert app model config to completion app config diff --git 
a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 8485ce7519..843328f904 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise MoreLikeThisDisabledError() app_model_config = message.app_model_config + if not app_model_config: + raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] completion_params = model_dict.get("completion_params") diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 4d45c61145..d7e9ebdf24 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) + if not isinstance(metadata, dict): + metadata = {} sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 92f3b6507c..170c6a274b 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -84,7 +84,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): logger.exception("Failed to handle response, conversation_id: %s", conversation.id) raise e - def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig: if conversation: stmt = select(AppModelConfig).where( AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id @@ -112,7 +112,7 @@ class 
MessageBasedAppGenerator(BaseAppGenerator):
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
        ],
-        conversation: Optional[Conversation] = None,
+        conversation: Conversation | None = None,
    ) -> tuple[Conversation, Message]:
        """
        Initialize generate records
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index ab738c48eb..45d047434b 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -425,6 +425,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
        context: contextvars.Context,
        variable_loader: VariableLoader,
    ) -> None:
+        """
+        Generate worker in a new thread.
+        :param flask_app: Flask app
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param workflow_thread_pool_id: workflow thread pool id
+        :param context: context vars to preserve in the worker thread
+        :param variable_loader: variable loader
+        :return:
+        """
        with preserve_flask_contexts(flask_app, context_vars=context):
            with Session(db.engine, expire_on_commit=False) as session:
                workflow = session.scalar(
diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py
index 210f6110b1..01ecf0298f 100644
--- a/api/core/app/apps/workflow/generate_response_converter.py
+++ b/api/core/app/apps/workflow/generate_response_converter.py
@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
        :param blocking_response: blocking response
        :return:
        """
-        return dict(blocking_response.to_dict())
+        return blocking_response.model_dump()

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
@@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

-            response_chunk = {
+            response_chunk: dict[str, object] = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
@@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                yield "ping"
                continue

-            response_chunk = {
+            response_chunk: dict[str, object] = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
            else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index ae7dd0dc28..56b0d91141 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -2,7 +2,7 @@ import logging
 import time
 from collections.abc import Callable, Generator
 from contextlib import contextmanager
-from typing import Optional, Union
+from typing import Union

 from sqlalchemy.orm import Session

@@ -133,7 +133,7 @@ class 
WorkflowAppGenerateTaskPipeline: self._application_generate_entity = application_generate_entity self._workflow_features_dict = workflow.features_dict self._workflow_run_id = "" - self._invoke_from = queue_manager._invoke_from + self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -142,7 +142,7 @@ class WorkflowAppGenerateTaskPipeline: :return: """ generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -202,7 +202,7 @@ class WorkflowAppGenerateTaskPipeline: return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id @@ -264,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: """Fluent validation for graph runtime state.""" if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -272,12 +272,12 @@ class WorkflowAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" - err = self._base_task_pipeline._handle_error(event=event) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event( self, event: QueueWorkflowStartedEvent, **kwargs @@ -442,8 +442,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" @@ -476,8 +476,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" @@ -511,8 +511,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + 
trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed and stop events.""" @@ -549,8 +549,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueTextChunkEvent, *, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle text chunk events.""" @@ -601,10 +601,10 @@ class WorkflowAppGenerateTaskPipeline: self, event: AppQueueEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + graph_runtime_state: GraphRuntimeState | None = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, ) -> Generator[StreamResponse, None, None]: """Dispatch events using elegant pattern matching.""" handlers = self._get_event_handlers() @@ -654,8 +654,8 @@ class WorkflowAppGenerateTaskPipeline: def _process_stream_response( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response using elegant Fluent Python patterns. @@ -722,7 +722,7 @@ class WorkflowAppGenerateTaskPipeline: session.commit() def _text_chunk_to_stream_response( - self, text: str, from_variable_selector: Optional[list[str]] = None + self, text: str, from_variable_selector: list[str] | None = None ) -> TextChunkStreamResponse: """ Handle completed event. diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 4f57ee1ff0..6ed596bfb8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -99,8 +99,8 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: Any - file_upload_config: Optional[FileUploadConfig] = None + app_config: Any = None + file_upload_config: FileUploadConfig | None = None inputs: Mapping[str, Any] files: Sequence[File] @@ -126,10 +126,10 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: EasyUIBasedAppConfig + app_config: EasyUIBasedAppConfig = None # type: ignore model_conf: ModelConfigWithCredentialsEntity - query: Optional[str] = None + query: str | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -140,8 +140,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity): Base entity for conversation-based app generation. """ - conversation_id: Optional[str] = None - parent_message_id: Optional[str] = Field( + conversation_id: str | None = None + parent_message_id: str | None = Field( default=None, description=( "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API." 
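An aside on the `app_config: ... = None  # type: ignore` defaults above: they rely on pydantic v2 leaving field defaults unvalidated (unless `validate_default=True`), so the annotation stays non-optional for readers while concrete generators are expected to always pass a value. A toy sketch of that behavior, not the real entities:

```python
# Toy models, not the real entities: pydantic v2 does not validate defaults
# unless validate_default=True, so a non-optional field can default to None.
from pydantic import BaseModel


class AppConfigStub(BaseModel):
    name: str


class GenerateEntityStub(BaseModel):
    app_config: AppConfigStub = None  # type: ignore[assignment]


print(GenerateEntityStub().app_config)   # None: the default bypasses validation
print(GenerateEntityStub(app_config=AppConfigStub(name="demo")).app_config.name)  # demo
# GenerateEntityStub(app_config=123) would still raise a ValidationError.
```

Declaring the field as `WorkflowUIBasedAppConfig | None` instead would force None-checks at every use site, which is presumably why the patch opts for the ignore comment.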
@@ -189,9 +189,9 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None query: str class SingleIterationRunEntity(BaseModel): @@ -202,7 +202,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): node_id: str inputs: Mapping - single_iteration_run: Optional[SingleIterationRunEntity] = None + single_iteration_run: SingleIterationRunEntity | None = None class SingleLoopRunEntity(BaseModel): """ @@ -212,7 +212,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): node_id: str inputs: Mapping - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None class WorkflowAppGenerateEntity(AppGenerateEntity): @@ -221,7 +221,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_execution_id: str class SingleIterationRunEntity(BaseModel): @@ -232,7 +232,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None + single_iteration_run: SingleIterationRunEntity | None = None class SingleLoopRunEntity(BaseModel): """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 17d8df9bb9..376f52b3b5 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel @@ -79,9 +79,9 @@ class QueueIterationStartEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None + metadata: Mapping[str, Any] | None = None class QueueIterationNextEvent(AppQueueEvent): @@ -114,12 +114,12 @@ class QueueIterationCompletedEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueLoopStartEvent(AppQueueEvent): @@ -132,20 +132,20 @@ class QueueLoopStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: 
Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None + metadata: Mapping[str, Any] | None = None class QueueLoopNextEvent(AppQueueEvent): @@ -160,15 +160,15 @@ class QueueLoopNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" node_run_index: int output: Optional[Any] = None # output for the current loop @@ -187,21 +187,21 @@ class QueueLoopCompletedEvent(AppQueueEvent): node_title: str parallel_id: Optional[str] = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueTextChunkEvent(AppQueueEvent): @@ -211,11 +211,11 @@ class QueueTextChunkEvent(AppQueueEvent): event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None """from variable selector""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -252,9 +252,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: Sequence[RetrievalSourceMetadata] - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -273,7 +273,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.MESSAGE_END - llm_result: Optional[LLMResult] = None + llm_result: LLMResult | None = None class QueueAdvancedChatMessageEndEvent(AppQueueEvent): @@ -299,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class QueueWorkflowFailedEvent(AppQueueEvent): @@ -319,7 +319,7 @@ class 
QueueWorkflowPartialSuccessEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED exceptions_count: int - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class QueueNodeStartedEvent(AppQueueEvent): @@ -362,22 +362,22 @@ class QueueNodeSucceededEvent(AppQueueEvent): node_type: NodeType parallel_id: Optional[str] = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: Optional[str] = None @@ -391,11 +391,11 @@ class QueueAgentLogEvent(AppQueueEvent): id: str label: str node_execution_id: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, Any] | None = None node_id: str @@ -404,10 +404,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): event: QueueEvent = QueueEvent.RETRY - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str retry_index: int # retry index @@ -425,22 +425,22 @@ class QueueNodeExceptionEvent(AppQueueEvent): node_type: NodeType parallel_id: Optional[str] = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: 
Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -458,14 +458,14 @@ class QueueNodeFailedEvent(AppQueueEvent): parallel_id: Optional[str] = None in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -494,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ERROR - error: Optional[Any] = None + error: Any | None = None class QueuePingEvent(AppQueueEvent): @@ -510,15 +510,15 @@ class QueueStopEvent(AppQueueEvent): QueueStopEvent entity """ - class StopBy(Enum): + class StopBy(StrEnum): """ Stop by enum """ - USER_MANUAL = "user-manual" - ANNOTATION_REPLY = "annotation-reply" - OUTPUT_MODERATION = "output-moderation" - INPUT_MODERATION = "input-moderation" + USER_MANUAL = auto() + ANNOTATION_REPLY = auto() + OUTPUT_MODERATION = auto() + INPUT_MODERATION = auto() event: QueueEvent = QueueEvent.STOP stopped_by: StopBy diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 09e2db603c..c940a9ab6f 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,11 +1,10 @@ from collections.abc import Mapping, Sequence -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -51,7 +50,7 @@ class WorkflowTaskState(TaskState): answer: str = "" -class StreamEvent(Enum): +class StreamEvent(StrEnum): """ Stream event """ @@ -90,9 +89,6 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ErrorStreamResponse(StreamResponse): """ @@ -112,7 +108,7 @@ class MessageStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE id: str answer: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None class MessageAudioStreamResponse(StreamResponse): @@ -141,7 +137,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = Field(default_factory=dict) - files: Optional[Sequence[Mapping[str, Any]]] = None + files: Sequence[Mapping[str, Any]] | None = None class MessageFileStreamResponse(StreamResponse): @@ -174,12 +170,12 @@ class AgentThoughtStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int - 
thought: Optional[str] = None - observation: Optional[str] = None - tool: Optional[str] = None - tool_labels: Optional[dict] = None - tool_input: Optional[str] = None - message_files: Optional[list[str]] = None + thought: str | None = None + observation: str | None = None + tool: str | None = None + tool_labels: dict | None = None + tool_input: str | None = None + message_files: list[str] | None = None class AgentMessageStreamResponse(StreamResponse): @@ -225,16 +221,16 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int - created_by: Optional[dict] = None + created_by: dict | None = None created_at: int finished_at: int - exceptions_count: Optional[int] = 0 - files: Optional[Sequence[Mapping[str, Any]]] = [] + exceptions_count: int | None = 0 + files: Sequence[Mapping[str, Any]] | None = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -261,14 +257,14 @@ class NodeStartStreamResponse(StreamResponse): inputs_truncated: bool = False created_at: int extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None - parallel_run_id: Optional[str] = None - agent_strategy: Optional[AgentNodeStrategyInit] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None + parallel_run_id: str | None = None + agent_strategy: AgentNodeStrategyInit | None = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -322,18 +318,18 @@ class NodeFinishStreamResponse(StreamResponse): outputs: Optional[Mapping[str, Any]] = None outputs_truncated: bool = True status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -394,18 +390,18 @@ class NodeRetryStreamResponse(StreamResponse): outputs: Optional[Mapping[str, Any]] = None outputs_truncated: bool = False status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: 
Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None retry_index: int = 0 event: StreamEvent = StreamEvent.NODE_RETRY @@ -514,10 +510,10 @@ class IterationNodeCompletedStreamResponse(StreamResponse): inputs: Optional[Mapping] = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping | None = None finished_at: int steps: int @@ -569,7 +565,7 @@ class LoopNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - pre_loop_output: Optional[Any] = None + pre_loop_output: Any | None = None extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None @@ -601,14 +597,14 @@ class LoopNodeCompletedStreamResponse(StreamResponse): inputs: Optional[Mapping] = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping | None = None finished_at: int steps: int - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_COMPLETED workflow_run_id: str @@ -626,7 +622,7 @@ class TextChunkStreamResponse(StreamResponse): """ text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -688,7 +684,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): WorkflowAppStreamResponse entity """ - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None class AppBlockingResponse(BaseModel): @@ -698,9 +694,6 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ChatbotAppBlockingResponse(AppBlockingResponse): """ @@ -756,8 +749,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int @@ -781,11 +774,11 @@ class AgentLogStreamResponse(StreamResponse): node_execution_id: str id: str label: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, Any] | None = None node_id: str event: StreamEvent = StreamEvent.AGENT_LOG diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index be183e2086..79fbafe39e 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ 
b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from sqlalchemy import select @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) class AnnotationReplyFeature: def query( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record @@ -35,6 +34,9 @@ class AnnotationReplyFeature: collection_binding_detail = annotation_setting.collection_binding_detail + if not collection_binding_detail: + return None + try: score_threshold = annotation_setting.score_threshold or 1 embedding_provider_name = collection_binding_detail.provider_name diff --git a/api/core/app/features/rate_limiting/__init__.py b/api/core/app/features/rate_limiting/__init__.py index 6624f6ad9d..4ad33acd0f 100644 --- a/api/core/app/features/rate_limiting/__init__.py +++ b/api/core/app/features/rate_limiting/__init__.py @@ -1 +1,3 @@ from .rate_limit import RateLimit + +__all__ = ["RateLimit"] diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f526d2a16a..ffa10cd43c 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -3,7 +3,7 @@ import time import uuid from collections.abc import Generator, Mapping from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Union from core.errors.error import AppInvokeQuotaExceededError from extensions.ext_redis import redis_client @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} - def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): + def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance @@ -63,7 +63,7 @@ class RateLimit: if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) - def enter(self, request_id: Optional[str] = None) -> str: + def enter(self, request_id: str | None = None) -> str: if self.disabled(): return RateLimit._UNLIMITED_REQUEST_ID if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 7d98cceb1a..45e3c0006b 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional from sqlalchemy import select from sqlalchemy.orm import Session @@ -38,11 +37,11 @@ class BasedGenerateTaskPipeline: ): self._application_generate_entity = application_generate_entity self.queue_manager = queue_manager - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - self._stream = stream + self.start_at = time.perf_counter() + self.output_moderation_handler = self._init_output_moderation() + self.stream = stream - def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): + def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = 
""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -86,7 +85,7 @@ class BasedGenerateTaskPipeline: return message - def _error_to_stream_response(self, e: Exception): + def error_to_stream_response(self, e: Exception): """ Error to stream response. :param e: exception @@ -94,14 +93,14 @@ class BasedGenerateTaskPipeline: """ return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) - def _ping_stream_response(self) -> PingStreamResponse: + def ping_stream_response(self) -> PingStreamResponse: """ Ping stream response. :return: """ return PingStreamResponse(task_id=self._application_generate_entity.task_id) - def _init_output_moderation(self) -> Optional[OutputModeration]: + def _init_output_moderation(self) -> OutputModeration | None: """ Init output moderation. :return: @@ -118,21 +117,21 @@ class BasedGenerateTaskPipeline: ) return None - def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + def handle_output_moderation_when_task_finished(self, completion: str) -> str | None: """ Handle output moderation when task finished. :param completion: completion :return: """ # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() - completion, flagged = self._output_moderation_handler.moderation_completion( + completion, flagged = self.output_moderation_handler.moderation_completion( completion=completion, public_event=False ) - self._output_moderation_handler = None + self.output_moderation_handler = None if flagged: return completion diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 0dad0a5a9d..67abb569e3 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): task_state=self._task_state, ) - self._conversation_name_generate_thread: Optional[Thread] = None + self._conversation_name_generate_thread: Thread | None = None def process( self, @@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._task_state.metadata: extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation_mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( @@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + 
self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id @@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if isinstance(event, QueueErrorEvent): with Session(db.engine) as session: - err = self._handle_error(event=event, session=session, message_id=self._message_id) + err = self.handle_error(event=event, session=session, message_id=self._message_id) session.commit() - yield self._error_to_stream_response(err) + yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): @@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._handle_stop(event) # handle output moderation - output_moderation_answer = self._handle_output_moderation_when_task_finished( + output_moderation_answer = self.handle_output_moderation_when_task_finished( cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: @@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): elif isinstance(event, QueueMessageReplaceEvent): yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self.ping_stream_response() else: continue if publisher: @@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): + def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None): """ Save message. 
:return: @@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.answer_tokens = usage.completion_tokens message.answer_unit_price = usage.completion_unit_price message.answer_price_unit = usage.completion_price_unit - message.provider_response_latency = time.perf_counter() - self._start_at + message.provider_response_latency = time.perf_counter() - self.start_at message.total_price = usage.total_price message.currency = usage.currency self._task_state.llm_result.usage.latency = message.provider_response_latency @@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): # transform usage model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._task_state.llm_result.usage = model_type_instance._calc_response_usage( + self._task_state.llm_result.usage = model_type_instance.calc_response_usage( model, credentials, prompt_tokens, completion_tokens ) @@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) - def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: + def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None: """ Agent thought to stream response. :param event: agent thought event :return: """ with Session(db.engine, expire_on_commit=False) as session: - agent_thought: Optional[MessageAgentThought] = ( + agent_thought: MessageAgentThought | None = ( session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() ) @@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): + if self.output_moderation_handler: + if self.output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output - self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() + self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output() self.queue_manager.publish( QueueLLMChunkEvent( chunk=LLMResultChunk( @@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) return True else: - self._output_moderation_handler.append_new_token(text) + self.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index e865ba9d60..90ffdcf1f6 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -1,6 +1,6 @@ import logging from threading import Thread -from typing import Optional, Union +from typing import Union from flask import Flask, current_app from sqlalchemy import select @@ -52,7 +52,7 @@ class MessageCycleManager: self._application_generate_entity = application_generate_entity self._task_state = task_state - def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ Generate conversation name. 
:param conversation_id: conversation id @@ -92,7 +92,7 @@ class MessageCycleManager: if not conversation: return - if conversation.mode != AppMode.COMPLETION.value: + if conversation.mode != AppMode.COMPLETION: app_model = conversation.app if not app_model: return @@ -111,7 +111,7 @@ class MessageCycleManager: db.session.commit() db.session.close() - def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None: """ Handle annotation reply. :param event: event @@ -141,7 +141,7 @@ class MessageCycleManager: if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata.retriever_resources = event.retriever_resources - def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None: """ Message file to stream response. :param event: event @@ -180,7 +180,7 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + self, answer: str, message_id: str, from_variable_selector: list[str] | None = None ) -> MessageStreamResponse: """ Message to stream response. diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 4e6422e2df..1e0fba6215 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -5,7 +5,6 @@ import queue import re import threading from collections.abc import Iterable -from typing import Optional from core.app.entities.queue_entities import ( MessageQueueMessage, @@ -56,7 +55,7 @@ def _process_future( class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None): + def __init__(self, tenant_id: str, voice: str, language: str | None = None): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" @@ -72,8 +71,8 @@ class AppGeneratorTTSPublisher: self.voice = voice if not voice or voice not in values: self.voice = self.voices[0].get("value") - self.MAX_SENTENCE = 2 - self._last_audio_event: Optional[AudioTrunk] = None + self.max_sentence = 2 + self._last_audio_event: AudioTrunk | None = None # FIXME better way to handle this threading.start threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) @@ -113,8 +112,8 @@ class AppGeneratorTTSPublisher: self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) - if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): - self.MAX_SENTENCE += 1 + if len(sentence_arr) >= min(self.max_sentence, 7): + self.max_sentence += 1 text_content = "".join(sentence_arr) futures_result = self.executor.submit( _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 9036d561d6..6591b08a7e 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,5 +1,5 @@ from collections.abc import Iterable, Mapping -from typing import Any, Optional, 
TextIO, Union +from typing import Any, TextIO, Union from pydantic import BaseModel @@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str: return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): +def print_text(text: str, color: str | None = None, end: str = "", file: TextIO | None = None): """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) @@ -34,10 +34,10 @@ def print_text(text: str, color: Optional[str] = None, end: str = "", file: Opti class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = "" + color: str | None = "" current_loop: int = 1 - def __init__(self, color: Optional[str] = None): + def __init__(self, color: str | None = None): super().__init__() """Initialize callback handler.""" # use a specific color is not specified @@ -58,9 +58,9 @@ class DifyAgentCallbackHandler(BaseModel): tool_name: str, tool_inputs: Mapping[str, Any], tool_outputs: Iterable[ToolInvokeMessage] | str, - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, + message_id: str | None = None, + timer: Any | None = None, + trace_manager: TraceQueueManager | None = None, ): """If not the final action, print out observation.""" if dify_config.DEBUG: @@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel): else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): + def on_agent_finish(self, color: str | None = None, **kwargs: Any): """Run on agent end.""" if dify_config.DEBUG: print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 350b18772b..23aabd9970 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Iterable, Mapping -from typing import Any, Optional +from typing import Any from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text from core.ops.ops_trace_manager import TraceQueueManager @@ -14,9 +14,9 @@ class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): tool_name: str, tool_inputs: Mapping[str, Any], tool_outputs: Iterable[ToolInvokeMessage], - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, + message_id: str | None = None, + timer: Any | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[ToolInvokeMessage, None, None]: for tool_output in tool_outputs: print_text("\n[on_tool_execution]\n", color=self.color) diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 656bf4aa72..cf958b91d2 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -1,8 +1,8 @@ -from enum import Enum +from enum import StrEnum, auto -class PlanningStrategy(Enum): - ROUTER = "router" - REACT_ROUTER = "react_router" - REACT = "react" - FUNCTION_CALL = "function_call" +class PlanningStrategy(StrEnum): + ROUTER = auto() + REACT_ROUTER 
= auto() + REACT = auto() + FUNCTION_CALL = auto() diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 9b4934646b..89b48fd2ef 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,10 @@ -from enum import Enum +from enum import StrEnum, auto -class EmbeddingInputType(Enum): +class EmbeddingInputType(StrEnum): """ Enum for embedding input type. """ - DOCUMENT = "document" - QUERY = "query" + DOCUMENT = auto() + QUERY = auto() diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 63fce06005..33e1f64579 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class PreviewDetail(BaseModel): content: str - child_chunks: Optional[list[str]] = None + child_chunks: list[str] | None = None class QAPreviewDetail(BaseModel): diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 0fd49b059c..663a8164c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from enum import Enum -from typing import Optional +from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict @@ -9,16 +8,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel from core.model_runtime.entities.provider_entities import ProviderEntity -class ModelStatus(Enum): +class ModelStatus(StrEnum): """ Enum class for model status. """ - ACTIVE = "active" + ACTIVE = auto() NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" - DISABLED = "disabled" + DISABLED = auto() CREDENTIAL_REMOVED = "credential-removed" @@ -29,8 +28,8 @@ class SimpleModelProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: list[ModelType] def __init__(self, provider_entity: ProviderEntity): @@ -92,8 +91,8 @@ class DefaultModelProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] = [] diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index fbd62437e6..0afb51edce 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -1,20 +1,20 @@ -from enum import StrEnum +from enum import StrEnum, auto class CommonParameterType(StrEnum): SECRET_INPUT = "secret-input" TEXT_INPUT = "text-input" - SELECT = "select" - STRING = "string" - NUMBER = "number" - FILE = "file" - FILES = "files" + SELECT = auto() + STRING = auto() + NUMBER = auto() + FILE = auto() + FILES = auto() SYSTEM_FILES = "system-files" - BOOLEAN = "boolean" + BOOLEAN = auto() APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" - ANY = "any" + ANY = auto() # Dynamic select parameter # Once you are not sure about the available options until authorization is done @@ -23,29 +23,29 @@ class CommonParameterType(StrEnum): # TOOL_SELECTOR = "tool-selector" # MCP object and array type parameters - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() 
class AppSelectorScope(StrEnum): - ALL = "all" - CHAT = "chat" - WORKFLOW = "workflow" - COMPLETION = "completion" + ALL = auto() + CHAT = auto() + WORKFLOW = auto() + COMPLETION = auto() class ModelSelectorScope(StrEnum): - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - TTS = "tts" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - VISION = "vision" + RERANK = auto() + TTS = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + VISION = auto() class ToolSelectorScope(StrEnum): - ALL = "all" - CUSTOM = "custom" - BUILTIN = "builtin" - WORKFLOW = "workflow" + ALL = auto() + CUSTOM = auto() + BUILTIN = auto() + WORKFLOW = auto() diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 47732216dd..de3b0964ff 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -4,7 +4,6 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from typing import Optional from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import func, select @@ -92,7 +91,7 @@ class ProviderConfiguration(BaseModel): ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) - def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -165,7 +164,7 @@ class ProviderConfiguration(BaseModel): return credentials - def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: + def get_system_configuration_status(self) -> SystemConfigurationStatus | None: """ Get system configuration status. :return: @@ -794,9 +793,7 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_custom_model_credential( - self, model_type: ModelType, model: str, credential_id: str | None - ) -> Optional[dict]: + def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: """ Get custom model credentials. @@ -1274,7 +1271,7 @@ class ProviderConfiguration(BaseModel): return model_setting - def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: + def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: """ Get provider model setting. :param model_type: model type @@ -1451,7 +1448,7 @@ class ProviderConfiguration(BaseModel): def get_provider_model( self, model_type: ModelType, model: str, only_active: bool = False - ) -> Optional[ModelWithProviderEntity]: + ) -> ModelWithProviderEntity | None: """ Get provider model. :param model_type: model type @@ -1468,7 +1465,7 @@ class ProviderConfiguration(BaseModel): return None def get_provider_models( - self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None + self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None ) -> list[ModelWithProviderEntity]: """ Get provider models. 
@@ -1652,7 +1649,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], - model: Optional[str] = None, + model: str | None = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -1786,7 +1783,7 @@ class ProviderConfigurations(BaseModel): super().__init__(tenant_id=tenant_id) def get_models( - self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -1843,8 +1840,14 @@ class ProviderConfigurations(BaseModel): def __setitem__(self, key, value): self.configurations[key] = value + def __contains__(self, key): + if "/" not in key: + key = str(ModelProviderID(key)) + return key in self.configurations + def __iter__(self): - return iter(self.configurations) + # Return an iterator of (key, value) tuples to match BaseModel's __iter__ + yield from self.configurations.items() def values(self) -> Iterator[ProviderConfiguration]: return iter(self.configurations.values()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 9b8baf1973..0496959ce2 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,5 +1,5 @@ -from enum import Enum -from typing import Optional, Union +from enum import StrEnum, auto +from typing import Union from pydantic import BaseModel, ConfigDict, Field @@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod @@ -31,25 +31,25 @@ class ProviderQuotaType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class QuotaUnit(Enum): - TIMES = "times" - TOKENS = "tokens" - CREDITS = "credits" +class QuotaUnit(StrEnum): + TIMES = auto() + TOKENS = auto() + CREDITS = auto() -class SystemConfigurationStatus(Enum): +class SystemConfigurationStatus(StrEnum): """ Enum class for system configuration status. 
""" - ACTIVE = "active" + ACTIVE = auto() QUOTA_EXCEEDED = "quota-exceeded" - UNSUPPORTED = "unsupported" + UNSUPPORTED = auto() class RestrictModel(BaseModel): model: str - base_model_name: Optional[str] = None + base_model_name: str | None = None model_type: ModelType # pydantic configs @@ -84,9 +84,9 @@ class SystemConfiguration(BaseModel): """ enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None + current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] - credentials: Optional[dict] = None + credentials: dict | None = None class CustomProviderConfiguration(BaseModel): @@ -95,8 +95,8 @@ class CustomProviderConfiguration(BaseModel): """ credentials: dict - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None + current_credential_id: str | None = None + current_credential_name: str | None = None available_credentials: list[CredentialConfiguration] = [] @@ -107,11 +107,11 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None + credentials: dict | None = None + current_credential_id: str | None = None + current_credential_name: str | None = None available_model_credentials: list[CredentialConfiguration] = [] - unadded_to_model_list: Optional[bool] = False + unadded_to_model_list: bool | None = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -131,7 +131,7 @@ class CustomConfiguration(BaseModel): Model class for provider custom configuration. """ - provider: Optional[CustomProviderConfiguration] = None + provider: CustomProviderConfiguration | None = None models: list[CustomModelConfiguration] = [] can_added_models: list[UnaddedModelConfiguration] = [] @@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel): Base model class for common provider settings like credentials """ - class Type(Enum): - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - TEXT_INPUT = CommonParameterType.TEXT_INPUT.value - SELECT = CommonParameterType.SELECT.value - BOOLEAN = CommonParameterType.BOOLEAN.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + class Type(StrEnum): + SECRET_INPUT = CommonParameterType.SECRET_INPUT + TEXT_INPUT = CommonParameterType.TEXT_INPUT + SELECT = CommonParameterType.SELECT + BOOLEAN = CommonParameterType.BOOLEAN + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod def value_of(cls, value: str) -> "ProviderConfig.Type": @@ -205,12 +205,12 @@ class ProviderConfig(BasicProviderConfig): scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None required: bool = False - default: Optional[Union[int, str, float, bool]] = None - options: Optional[list[Option]] = None - label: Optional[I18nObject] = None - help: Optional[I18nObject] = None - url: Optional[str] = None - placeholder: Optional[I18nObject] = None + default: Union[int, str, float, bool] | None = None + options: list[Option] | None = None + label: I18nObject | None = None + help: I18nObject | None = None + url: str | None = None + placeholder: I18nObject | None = None def to_basic_provider_config(self) -> BasicProviderConfig: return BasicProviderConfig(type=self.type, name=self.name) 
diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 642f24a411..8c1ba98ae1 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,12 +1,9 @@ -from typing import Optional - - class LLMError(ValueError): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index eee914a529..c2789a7a35 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,10 +1,10 @@ -import enum import importlib.util import json import logging import os +from enum import StrEnum, auto from pathlib import Path -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -13,18 +13,18 @@ from core.helper.position_helper import sort_to_dict_by_position_map logger = logging.getLogger(__name__) -class ExtensionModule(enum.Enum): - MODERATION = "moderation" - EXTERNAL_DATA_TOOL = "external_data_tool" +class ExtensionModule(StrEnum): + MODERATION = auto() + EXTERNAL_DATA_TOOL = auto() class ModuleExtension(BaseModel): - extension_class: Optional[Any] = None + extension_class: Any | None = None name: str - label: Optional[dict] = None - form_schema: Optional[list] = None + label: dict | None = None + form_schema: list | None = None builtin: bool = True - position: Optional[int] = None + position: int | None = None class Extensible: @@ -32,9 +32,9 @@ class Extensible: name: str tenant_id: str - config: Optional[dict] = None + config: dict | None = None - def __init__(self, tenant_id: str, config: Optional[dict] = None): + def __init__(self, tenant_id: str, config: dict | None = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 45878e763f..564801f189 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor @@ -39,7 +37,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: dict, query: str | None = None) -> str: """ Query the external data tool. 
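The Optional[T] -> T | None rewrites throughout these files are behavior-preserving: under PEP 604 the two spellings construct equal type objects, so runtime introspection and Pydantic validation see no difference. A quick check of the assumed semantics (standard Python 3.10+ behavior, not part of the patch):

```python
from typing import Optional, Union, get_type_hints

def query(inputs: dict, query: str | None = None) -> str:
    return query or ""

assert (str | None) == Optional[str] == Union[str, None]  # PEP 604 unions compare equal
assert get_type_hints(query)["query"] == Optional[str]    # introspection agrees

# Caveat: quoted forward references do not support the operator form at runtime --
# "SomeModel" | None raises TypeError -- so annotations like that must keep Optional[...].
```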
diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 81f1aaf174..cbec2e4e42 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from core.extension.extensible import Extensible, ExtensionModule @@ -16,7 +15,7 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): + def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @@ -34,7 +33,7 @@ class ExternalDataTool(Extensible, ABC): raise NotImplementedError @abstractmethod - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: dict, query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 6a9703a569..86bbb7060c 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from typing import Any, Optional +from typing import Any from flask import Flask, current_app @@ -63,7 +63,7 @@ class ExternalDataFetch: external_data_tool: ExternalDataVariableEntity, inputs: Mapping[str, Any], query: str, - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Query external data tool. :param flask_app: flask app diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 538bc3f525..6c542d681b 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -26,7 +26,7 @@ class ExternalDataToolFactory: # FIXME mypy issue here, figure out how to fix it extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ed1779fd13..0665ed7e0d 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -9,7 +9,3 @@ FILE_MODEL_IDENTITY = "__dify__file__" def maybe_file_object(o: Any) -> bool: return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY - - -# The default user ID for service API calls. 
-DEFAULT_SERVICE_API_USER_ID = "DEFAULT-USER" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index cd3e94f798..ce2ece48e1 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -138,9 +138,9 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 7cb5d0f2da..fffda3d5fa 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -6,7 +6,6 @@ import time import urllib.parse from configs import dify_config -from core.file.constants import DEFAULT_SERVICE_API_USER_ID def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str: @@ -30,10 +29,6 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, # Plugin access should use internal URL for Docker network communication base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL url = f"{base_url}/files/upload/for-plugin" - - if user_id is None: - user_id = DEFAULT_SERVICE_API_USER_ID - timestamp = str(int(time.time())) nonce = os.urandom(16).hex() key = dify_config.SECRET_KEY.encode() @@ -45,11 +40,8 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str ) -> bool: - if user_id is None: - user_id = DEFAULT_SERVICE_API_USER_ID - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() diff --git a/api/core/file/models.py b/api/core/file/models.py index 1c6d00614c..990d3fe91d 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -26,7 +26,7 @@ class FileUploadConfig(BaseModel): File Upload Entity. """ - image_config: Optional[ImageConfig] = None + image_config: ImageConfig | None = None allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_extensions: Sequence[str] = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) @@ -38,21 +38,21 @@ class File(BaseModel): # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY - id: Optional[str] = None # message file id + id: str | None = None # message file id tenant_id: str type: FileType transfer_method: FileTransferMethod # If `transfer_method` is `FileTransferMethod.remote_url`, the # `remote_url` attribute must not be `None`. - remote_url: Optional[str] = None # remote url + remote_url: str | None = None # remote url # If `transfer_method` is `FileTransferMethod.local_file` or # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. 
# # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: Optional[str] = None - filename: Optional[str] = None - extension: Optional[str] = Field(default=None, description="File extension, should contain dot") - mime_type: Optional[str] = None + related_id: str | None = None + filename: str | None = None + extension: str | None = Field(default=None, description="File extension, should contain dot") + mime_type: str | None = None size: int = -1 # Those properties are private, should not be exposed to the outside. @@ -61,19 +61,19 @@ class File(BaseModel): def __init__( self, *, - id: Optional[str] = None, + id: str | None = None, tenant_id: str, type: FileType, transfer_method: FileTransferMethod, - remote_url: Optional[str] = None, - related_id: Optional[str] = None, - filename: Optional[str] = None, - extension: Optional[str] = None, - mime_type: Optional[str] = None, + remote_url: str | None = None, + related_id: str | None = None, + filename: str | None = None, + extension: str | None = None, + mime_type: str | None = None, size: int = -1, - storage_key: Optional[str] = None, - dify_model_identity: Optional[str] = FILE_MODEL_IDENTITY, - url: Optional[str] = None, + storage_key: str | None = None, + dify_model_identity: str | None = FILE_MODEL_IDENTITY, + url: str | None = None, ): super().__init__( id=id, @@ -108,7 +108,7 @@ class File(BaseModel): return text - def generate_url(self) -> Optional[str]: + def generate_url(self) -> str | None: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.remote_url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: @@ -146,3 +146,11 @@ class File(BaseModel): if not self.related_id: raise ValueError("Missing file related_id") return self + + @property + def storage_key(self) -> str: + return self._storage_key + + @storage_key.setter + def storage_key(self, value: str): + self._storage_key = value diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 2b580cb373..c44a8e1840 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -2,7 +2,7 @@ import logging from collections.abc import Mapping from enum import StrEnum from threading import Lock -from typing import Any, Optional +from typing import Any from httpx import Timeout, post from pydantic import BaseModel @@ -24,8 +24,8 @@ class CodeExecutionError(Exception): class CodeExecutionResponse(BaseModel): class Data(BaseModel): - stdout: Optional[str] = None - error: Optional[str] = None + stdout: str | None = None + error: str | None = None code: int message: str diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 1c112007cb..00fcfe0b80 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,12 +1,11 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError -from typing import Optional from extensions.ext_redis import redis_client -class ProviderCredentialsCacheType(Enum): +class ProviderCredentialsCacheType(StrEnum): PROVIDER = "provider" MODEL = "provider_model" LOAD_BALANCING_MODEL = "load_balancing_provider_model" @@ -14,9 +13,9 @@ class ProviderCredentialsCacheType(Enum): class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): - self.cache_key = 
f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """ Get cached model provider credentials. diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 314f052832..2fc8fbf885 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -1,12 +1,14 @@ import os from collections import OrderedDict from collections.abc import Callable +from functools import lru_cache from typing import TypeVar from configs import dify_config -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import load_yaml_file_cached +@lru_cache(maxsize=128) def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping from name to index from a YAML file @@ -14,12 +16,17 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value """ + # FIXME(-LAN-): Cache position maps to prevent file descriptor exhaustion during high-load benchmarks position_file_path = os.path.join(folder_path, file_name) - yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) + try: + yaml_content = load_yaml_file_cached(file_path=position_file_path) + except Exception: + yaml_content = [] positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] return {name: index for index, name in enumerate(positions)} +@lru_cache(maxsize=128) def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping for tools from name to index from a YAML file. @@ -35,20 +42,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") - ) -def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: - """ - Get the mapping for providers from name to index from a YAML file. - :param folder_path: - :param file_name: the YAML file name, default to '_position.yaml' - :return: a dict with name as key and index as value - """ - position_map = get_position_map(folder_path, file_name=file_name) - return pin_position_map( - position_map, - pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, - ) - - def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: """ Pin the items in the pin list to the beginning of the position map. 
diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 26e738fced..ffb5148386 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any, Optional +from typing import Any from extensions.ext_redis import redis_client @@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC): """Generate cache key based on subclass implementation""" pass - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider credentials""" cached_credentials = redis_client.get(self.cache_key) if cached_credentials: @@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): class NoOpProviderCredentialCache: """No-op provider credential cache""" - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider credentials""" return None diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index efeba9e5ee..cbb78939d2 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -13,18 +13,18 @@ logger = logging.getLogger(__name__) SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True +http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True try: - HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY - http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() + config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + http_request_node_ssl_verify_lower = str(config_value).lower() if http_request_node_ssl_verify_lower == "true": - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True elif http_request_node_ssl_verify_lower == "false": - HTTP_REQUEST_NODE_SSL_VERIFY = False + http_request_node_ssl_verify = False else: raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") except NameError: - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] @@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) if "ssl_verify" not in kwargs: - kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY + kwargs["ssl_verify"] = http_request_node_ssl_verify ssl_verify = kwargs.pop("ssl_verify") diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 95a1086ca8..54674d4ff6 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,12 +1,11 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError -from typing import Optional from extensions.ext_redis import redis_client -class ToolParameterCacheType(Enum): +class ToolParameterCacheType(StrEnum): PARAMETER = "tool_parameter" @@ -15,11 +14,11 @@ class ToolParameterCache: self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str ): self.cache_key = ( - f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" f":identity_id:{identity_id}" ) - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """ Get cached model provider credentials. 
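The `ProviderCredentialsCacheType` and `ToolParameterCacheType` hunks can drop `.value` from the cache-key f-strings because both enums now subclass `StrEnum` (Python 3.11+), whose members are real strings. A quick illustration with a hypothetical enum of the same shape:

```python
from enum import Enum, StrEnum


class LegacyCacheType(Enum):
    PARAMETER = "tool_parameter"


class CacheType(StrEnum):
    PARAMETER = "tool_parameter"


# A plain Enum member formats as its qualified name, so `.value` was needed:
assert f"{LegacyCacheType.PARAMETER}" == "LegacyCacheType.PARAMETER"
assert f"{LegacyCacheType.PARAMETER.value}" == "tool_parameter"

# A StrEnum member *is* a str, so it interpolates and compares as its value:
assert f"{CacheType.PARAMETER}" == "tool_parameter"
assert CacheType.PARAMETER == "tool_parameter"
```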
diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 35e6e292d1..820502e558 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -1,7 +1,7 @@ import contextlib import re from collections.abc import Mapping -from typing import Any, Optional +from typing import Any def is_valid_trace_id(trace_id: str) -> bool: @@ -13,7 +13,7 @@ def is_valid_trace_id(trace_id: str) -> bool: return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id)) -def get_external_trace_id(request: Any) -> Optional[str]: +def get_external_trace_id(request: Any) -> str | None: """ Retrieve the trace_id from the request. @@ -61,7 +61,7 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]): return {} -def get_trace_id_from_otel_context() -> Optional[str]: +def get_trace_id_from_otel_context() -> str | None: """ Retrieve the current trace ID from the active OpenTelemetry trace context. Returns None if: @@ -88,7 +88,7 @@ def get_trace_id_from_otel_context() -> Optional[str]: return None -def parse_traceparent_header(traceparent: str) -> Optional[str]: +def parse_traceparent_header(traceparent: str) -> str | None: """ Parse the `traceparent` header to extract the trace_id. diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index a5d7f7aac7..af860a1070 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,5 +1,3 @@ -from typing import Optional - from flask import Flask from pydantic import BaseModel @@ -30,8 +28,8 @@ class FreeHostingQuota(HostingQuota): class HostingProvider(BaseModel): enabled: bool = False - credentials: Optional[dict] = None - quota_unit: Optional[QuotaUnit] = None + credentials: dict | None = None + quota_unit: QuotaUnit | None = None quotas: list[HostingQuota] = [] @@ -42,7 +40,7 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: provider_map: dict[str, HostingProvider] - moderation_config: Optional[HostedModerationConfig] = None + moderation_config: HostedModerationConfig | None = None def __init__(self): self.provider_map = {} diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 9f7255a72b..ee37024260 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,7 +5,7 @@ import re import threading import time import uuid -from typing import Any, Optional +from typing import Any from flask import current_app from sqlalchemy import select @@ -230,9 +230,9 @@ class IndexingRunner: tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: Optional[str] = None, + doc_form: str | None = None, doc_language: str = "English", - dataset_id: Optional[str] = None, + dataset_id: str | None = None, indexing_technique: str = "economy", ) -> IndexingEstimate: """ @@ -422,7 +422,7 @@ class IndexingRunner: max_tokens: int, chunk_overlap: int, separator: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
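For context on the `trace_id_helper.py` signatures above: `parse_traceparent_header` extracts the trace ID from a W3C Trace Context header of the form `{version}-{trace_id}-{parent_id}-{flags}`. Its body is not shown in these hunks, so the following is only a sketch of parsing that format, not the project's implementation:

```python
import re

# W3C traceparent: 2 hex version, 32 hex trace_id, 16 hex parent_id, 2 hex flags
_TRACEPARENT_RE = re.compile(r"^[0-9a-f]{2}-([0-9a-f]{32})-[0-9a-f]{16}-[0-9a-f]{2}$")


def parse_traceparent_header(traceparent: str) -> str | None:
    match = _TRACEPARENT_RE.match(traceparent.strip())
    if not match:
        return None
    trace_id = match.group(1)
    # An all-zero trace_id is invalid per the spec.
    return None if trace_id == "0" * 32 else trace_id


assert (
    parse_traceparent_header("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")
    == "0af7651916cd43dd8448eb211c80319c"
)
```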
@@ -530,6 +530,7 @@ class IndexingRunner: # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 + create_keyword_thread = None if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": # create keyword index create_keyword_thread = threading.Thread( @@ -568,7 +569,11 @@ class IndexingRunner: for future in futures: tokens += future.result() - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + and create_keyword_thread is not None + ): create_keyword_thread.join() indexing_end_at = time.perf_counter() @@ -651,7 +656,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + document_id: str, after_indexing_status: str, extra_update_params: dict | None = None ): """ Update the document indexing status. diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 427bb64a3b..6eb91d515c 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,7 +2,7 @@ import json import logging import re from collections.abc import Sequence -from typing import Optional, cast +from typing import cast import json_repair @@ -20,7 +20,7 @@ from core.llm_generator.prompts import ( ) from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) class LLMGenerator: @classmethod def generate_conversation_name( - cls, tenant_id: str, query, conversation_id: Optional[str] = None, app_id: Optional[str] = None + cls, tenant_id: str, query, conversation_id: str | None = None, app_id: str | None = None ): prompt = CONVERSATION_TITLE_PROMPT @@ -315,14 +315,20 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] + prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response: LLMResult = model_instance.invoke_llm( + # Explicitly use the non-streaming overload + result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False, ) + # Runtime type check since pyright has issues with the overload + if not isinstance(result, LLMResult): + raise TypeError("Expected LLMResult when stream=False") + response = result + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 28833fe8e8..1e302b7668 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator, Mapping, Sequence from copy import 
deepcopy from enum import StrEnum -from typing import Any, Literal, Optional, cast, overload +from typing import Any, Literal, cast, overload import json_repair from pydantic import TypeAdapter, ValidationError @@ -45,64 +45,62 @@ class SpecialModelType(StrEnum): @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, - stream: Literal[True] = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + stop: list[str] | None = None, + stream: Literal[True], + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, - stream: Literal[False] = False, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + stop: list[str] | None = None, + stream: Literal[False], + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... 
- - def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ Invoke large language model with structured output @@ -168,7 +166,7 @@ def invoke_llm_with_structured_output( def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]: result_text: str = "" prompt_messages: Sequence[PromptMessage] = [] - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None for event in llm_result: if isinstance(event, LLMResultChunk): prompt_messages = event.prompt_messages diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 5626849edf..7d938a8a7d 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -4,7 +4,6 @@ import json import os import secrets import urllib.parse -from typing import Optional from urllib.parse import urljoin, urlparse import httpx @@ -122,7 +121,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: return False, "" -def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: +def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" # First check if the server supports OAuth 2.0 Resource Discovery support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) @@ -152,7 +151,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N def start_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, redirect_url: str, provider_id: str, @@ -207,7 +206,7 @@ def start_authorization( def exchange_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, authorization_code: str, code_verifier: str, @@ -242,7 +241,7 @@ def exchange_authorization( def refresh_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, refresh_token: str, ) -> OAuthTokens: @@ -273,7 +272,7 @@ def refresh_authorization( def register_client( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, ) -> OAuthClientInformationFull: """Performs OAuth 2.0 Dynamic Client Registration.""" @@ -297,8 +296,8 @@ def register_client( def auth( provider: OAuthClientProvider, server_url: str, - authorization_code: Optional[str] = None, - state_param: Optional[str] = None, + authorization_code: str | None = None, + state_param: str | None = None, for_list: bool = False, ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" diff --git a/api/core/mcp/auth/auth_provider.py 
b/api/core/mcp/auth/auth_provider.py index bf1820f744..3a550eb1b6 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -1,5 +1,3 @@ -from typing import Optional - from configs import dify_config from core.mcp.types import ( OAuthClientInformation, @@ -37,7 +35,7 @@ class OAuthClientProvider: client_uri="https://github.com/langgenius/dify", ) - def client_information(self) -> Optional[OAuthClientInformation]: + def client_information(self) -> OAuthClientInformation | None: """Loads information about this OAuth client.""" client_information = self.mcp_provider.decrypted_credentials.get("client_information", {}) if not client_information: @@ -51,7 +49,7 @@ class OAuthClientProvider: {"client_information": client_information.model_dump()}, ) - def tokens(self) -> Optional[OAuthTokens]: + def tokens(self) -> OAuthTokens | None: """Loads any existing OAuth tokens for the current session.""" credentials = self.mcp_provider.decrypted_credentials if not credentials: diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index cc4263c0aa..6db22a09e0 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 @final class _StatusReady: def __init__(self, endpoint_url: str): - self._endpoint_url = endpoint_url + self.endpoint_url = endpoint_url @final class _StatusError: def __init__(self, exc: Exception): - self._exc = exc + self.exc = exc # Type aliases for better readability @@ -211,9 +211,9 @@ class SSETransport: raise ValueError("failed to get endpoint URL") if isinstance(status, _StatusReady): - return status._endpoint_url + return status.endpoint_url elif isinstance(status, _StatusError): - raise status._exc + raise status.exc else: raise ValueError("failed to get endpoint URL") diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 1012dc2810..86ec2c4db9 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -2,7 +2,7 @@ import logging from collections.abc import Callable from contextlib import AbstractContextManager, ExitStack from types import TracebackType -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from core.mcp.client.sse_client import sse_client @@ -21,11 +21,11 @@ class MCPClient: provider_id: str, tenant_id: str, authed: bool = True, - authorization_code: Optional[str] = None, + authorization_code: str | None = None, for_list: bool = False, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): # Initialize info self.provider_id = provider_id @@ -46,9 +46,9 @@ class MCPClient: self.token = self.provider.tokens() # Initialize session and client objects - self._session: Optional[ClientSession] = None - self._streams_context: Optional[AbstractContextManager[Any]] = None - self._session_context: Optional[ClientSession] = None + self._session: ClientSession | None = None + self._streams_context: AbstractContextManager[Any] | None = None + self._session_context: ClientSession | None = None self._exit_stack = ExitStack() # Whether the client has been initialized @@ -59,9 +59,7 @@ class MCPClient: self._initialized = True return self - def __exit__( - self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType] - ): + def __exit__(self, 
exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None): self.cleanup() def _initialize( diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 3d51ac2333..212c2eb073 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -38,6 +38,7 @@ def handle_mcp_request( """ request_type = type(request.root) + request_root = request.root def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: """Create success response with business result data""" @@ -58,21 +59,20 @@ def handle_mcp_request( error=error_data, ) - # Request handler mapping using functional approach - request_handlers = { - mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), - mcp_types.ListToolsRequest: lambda: handle_list_tools( - app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict - ), - mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), - mcp_types.PingRequest: lambda: handle_ping(), - } - try: - # Dispatch request to appropriate handler - handler = request_handlers.get(request_type) - if handler: - return create_success_response(handler()) + # Dispatch request to appropriate handler based on instance type + if isinstance(request_root, mcp_types.InitializeRequest): + return create_success_response(handle_initialize(mcp_server.description)) + elif isinstance(request_root, mcp_types.ListToolsRequest): + return create_success_response( + handle_list_tools( + app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + ) + ) + elif isinstance(request_root, mcp_types.CallToolRequest): + return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) + elif isinstance(request_root, mcp_types.PingRequest): + return create_success_response(handle_ping()) else: return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") @@ -142,7 +142,7 @@ def handle_call_tool( end_user, args, InvokeFrom.SERVICE_API, - streaming=app.mode == AppMode.AGENT_CHAT.value, + streaming=app.mode == AppMode.AGENT_CHAT, ) answer = extract_answer_from_response(app, response) @@ -157,7 +157,7 @@ def build_parameter_schema( """Build parameter schema for the tool""" parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) - if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}: return { "type": "object", "properties": parameters, @@ -175,9 +175,9 @@ def build_parameter_schema( def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW.value: + if app.mode == AppMode.WORKFLOW: return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION.value: + elif app.mode == AppMode.COMPLETION: return {"query": "", "inputs": arguments} else: # Chat modes - create a copy to avoid modifying original dict @@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" if app.mode in { - AppMode.ADVANCED_CHAT.value, - AppMode.COMPLETION.value, - AppMode.CHAT.value, - AppMode.AGENT_CHAT.value, + AppMode.ADVANCED_CHAT, + AppMode.COMPLETION, + AppMode.CHAT, + AppMode.AGENT_CHAT, }: return 
response.get("answer", "") - elif app.mode == AppMode.WORKFLOW.value: + elif app.mode == AppMode.WORKFLOW: return json.dumps(response["data"]["outputs"], ensure_ascii=False) else: raise ValueError("Invalid app mode: " + str(app.mode)) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 96c48034c7..653b3773c0 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -4,7 +4,7 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Optional, Self, TypeVar +from typing import Any, Generic, Self, TypeVar from httpx import HTTPStatusError from pydantic import BaseModel @@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self.request_meta = request_meta self.request = request self._session = session - self._completed = False + self.completed = False self._on_complete = on_complete self._entered = False # Track if we're in a context manager @@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): ): """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: + if self.completed: self._on_complete(self) finally: self._entered = False @@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """ if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" + assert not self.completed, "Request already responded to" - self._completed = True + self.completed = True self._session._send_response(request_id=self.request_id, response=response) @@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - self._completed = True # Mark as completed so it's removed from in_flight + self.completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation self._session._send_response( request_id=self.request_id, @@ -212,7 +212,7 @@ class BaseSession( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, - metadata: Optional[MessageMetadata] = None, + metadata: MessageMetadata | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -351,7 +351,7 @@ class BaseSession( self._in_flight[responder.request_id] = responder self._received_request(responder) - if not responder._completed: + if not responder.completed: self._handle_incoming(responder) elif isinstance(message.message.root, JSONRPCNotification): diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 49aa8e4498..7399e8a4b6 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -5,7 +5,6 @@ from typing import ( Any, Generic, Literal, - Optional, TypeAlias, TypeVar, ) @@ -809,7 +808,7 @@ class LoggingMessageNotificationParams(NotificationParams): """The severity of this log message.""" logger: str | None = None """An optional name of the logger issuing this message.""" - data: Any + data: Any = None """ The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. 
@@ -1173,45 +1172,45 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: Optional[MessageMetadata] = None + metadata: MessageMetadata | None = None class OAuthClientMetadata(BaseModel): client_name: str redirect_uris: list[str] - grant_types: Optional[list[str]] = None - response_types: Optional[list[str]] = None - token_endpoint_auth_method: Optional[str] = None - client_uri: Optional[str] = None - scope: Optional[str] = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + token_endpoint_auth_method: str | None = None + client_uri: str | None = None + scope: str | None = None class OAuthClientInformation(BaseModel): client_id: str - client_secret: Optional[str] = None + client_secret: str | None = None class OAuthClientInformationFull(OAuthClientInformation): client_name: str | None = None redirect_uris: list[str] - scope: Optional[str] = None - grant_types: Optional[list[str]] = None - response_types: Optional[list[str]] = None - token_endpoint_auth_method: Optional[str] = None + scope: str | None = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + token_endpoint_auth_method: str | None = None class OAuthTokens(BaseModel): access_token: str token_type: str - expires_in: Optional[int] = None - refresh_token: Optional[str] = None - scope: Optional[str] = None + expires_in: int | None = None + refresh_token: str | None = None + scope: str | None = None class OAuthMetadata(BaseModel): authorization_endpoint: str token_endpoint: str - registration_endpoint: Optional[str] = None + registration_endpoint: str | None = None response_types_supported: list[str] - grant_types_supported: Optional[list[str]] = None - code_challenge_methods_supported: Optional[list[str]] = None + grant_types_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7be695812a..35af742f2a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Optional from sqlalchemy import select @@ -32,11 +31,16 @@ class TokenBufferMemory: self.model_instance = model_instance def _build_prompt_message_with_files( - self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool + self, + message_files: Sequence[MessageFile], + text_content: str, + message: Message, + app_record, + is_user_message: bool, ) -> PromptMessage: """ Build prompt message with files. - :param message_files: list of MessageFile objects + :param message_files: Sequence of MessageFile objects :param text_content: text content of the message :param message: Message object :param app_record: app record @@ -91,7 +95,7 @@ class TokenBufferMemory: return AssistantPromptMessage(content=prompt_message_contents) def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: int | None = None ) -> Sequence[PromptMessage]: """ Get history prompt messages. 
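The next hunk replaces legacy `db.session.query(...).all()` calls with SQLAlchemy 2.0-style `select()` statements executed through `session.scalars()`. A self-contained sketch of that pattern, using an in-memory SQLite engine and a trimmed, hypothetical `MessageFile` model:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class MessageFile(Base):
    __tablename__ = "message_files"

    id: Mapped[int] = mapped_column(primary_key=True)
    message_id: Mapped[str]
    belongs_to: Mapped[str | None]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # 2.0 style: build a select() and let scalars() unwrap single-entity
    # rows, replacing the legacy session.query(MessageFile)...all() form.
    stmt = select(MessageFile).where(
        MessageFile.message_id == "msg-1",
        (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
    )
    user_files = session.scalars(stmt).all()
```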
@@ -128,14 +132,12 @@ class TokenBufferMemory: prompt_messages: list[PromptMessage] = [] for message in messages: # Process user message with files - user_files = ( - db.session.query(MessageFile) - .where( + user_files = db.session.scalars( + select(MessageFile).where( MessageFile.message_id == message.id, (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), ) - .all() - ) + ).all() if user_files: user_prompt_message = self._build_prompt_message_with_files( @@ -150,11 +152,9 @@ class TokenBufferMemory: prompt_messages.append(UserPromptMessage(content=message.query)) # Process assistant message with files - assistant_files = ( - db.session.query(MessageFile) - .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") - .all() - ) + assistant_files = db.session.scalars( + select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") + ).all() if assistant_files: assistant_prompt_message = self._build_prompt_message_with_files( @@ -186,7 +186,7 @@ class TokenBufferMemory: human_prefix: str = "Human", ai_prefix: str = "Assistant", max_token_limit: int = 2000, - message_limit: Optional[int] = None, + message_limit: int | None = None, ) -> str: """ Get history prompt text. diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 10df2ad79e..a63e94d59c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -103,47 +103,47 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[True] = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator: ... @overload def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[False] = False, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResult: ... @overload def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... 
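Several signatures in these hunks also widen `list[...]` parameters to `collections.abc.Sequence[...]` (`message_files` above, and `prompt_messages`/`stop` in `model_manager.py`). `Sequence` promises only read-only, ordered access and is covariant, so callers may pass a list, a tuple, or the `Sequence` that SQLAlchemy 2.0's `.all()` returns, without the invariance errors `list` triggers. A small sketch:

```python
from collections.abc import Sequence


def join_stop_words(stop: Sequence[str] | None = None) -> str:
    # Sequence guarantees len(), indexing, and iteration, but no mutation;
    # unlike list it is covariant, so list[str] and tuple[str, ...] both fit.
    return ",".join(stop or ())


assert join_stop_words(["\n", "Human:"]) == "\n,Human:"
assert join_stop_words(("\n",)) == "\n"
assert join_stop_words() == ""
```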
def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -176,7 +176,7 @@ class ModelInstance: ) def get_llm_num_tokens( - self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None + self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None ) -> int: """ Get number of tokens for llm @@ -199,7 +199,7 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> TextEmbeddingResult: """ Invoke large language model @@ -246,9 +246,9 @@ class ModelInstance: self, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -276,7 +276,7 @@ class ModelInstance: ), ) - def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -297,7 +297,7 @@ class ModelInstance: ), ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: + def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: """ Invoke large language model @@ -318,7 +318,7 @@ class ModelInstance: ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: """ Invoke large language tts model @@ -397,7 +397,7 @@ class ModelInstance: except Exception as e: raise e - def get_tts_voices(self, language: Optional[str] = None): + def get_tts_voices(self, language: str | None = None): """ Invoke large language tts model voices @@ -470,7 +470,7 @@ class LBModelManager: model_type: ModelType, model: str, load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None, + managed_credentials: dict | None = None, ): """ Load balancing model manager @@ -495,7 +495,7 @@ class LBModelManager: else: load_balancing_config.credentials = managed_credentials - def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: + def fetch_next(self) -> ModelLoadBalancingConfiguration | None: """ Get next model load balancing config Strategy: Round Robin diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 5ce4c23dbb..a745a91510 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from 
core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -31,10 +30,10 @@ class Callback(ABC): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Before invoke callback @@ -60,10 +59,10 @@ class Callback(ABC): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ On new chunk callback @@ -90,10 +89,10 @@ class Callback(ABC): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ After invoke callback @@ -120,10 +119,10 @@ class Callback(ABC): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Invoke error callback @@ -141,7 +140,7 @@ class Callback(ABC): """ raise NotImplementedError() - def print_text(self, text: str, color: Optional[str] = None, end: str = ""): + def print_text(self, text: str, color: str | None = None, end: str = ""): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 8411afca92..b366fcc57b 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -2,7 +2,7 @@ import json import logging import sys from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,10 +20,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Before invoke callback @@ -76,10 +76,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ On new chunk callback @@ -106,10 +106,10 @@ class 
LoggingCallback(Callback): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ After invoke callback @@ -147,10 +147,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Invoke error callback diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 659ad59bd6..c7353de5af 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -8,7 +6,7 @@ class I18nObject(BaseModel): Model class for i18n object. """ - zh_Hans: Optional[str] = None + zh_Hans: str | None = None en_US: str def __init__(self, **data): diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index d5caddb7a3..17f6000d93 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Any, Optional, TypedDict, Union +from typing import Any, TypedDict, Union from pydantic import BaseModel, Field @@ -150,13 +150,13 @@ class LLMResult(BaseModel): Model class for llm result. """ - id: Optional[str] = None + id: str | None = None model: str prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) message: AssistantPromptMessage usage: LLMUsage - system_fingerprint: Optional[str] = None - reasoning_content: Optional[str] = None + system_fingerprint: str | None = None + reasoning_content: str | None = None class LLMStructuredOutput(BaseModel): @@ -164,7 +164,7 @@ class LLMStructuredOutput(BaseModel): Model class for llm structured output. 
""" - structured_output: Optional[Mapping[str, Any]] = None + structured_output: Mapping[str, Any] | None = None class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): @@ -180,8 +180,8 @@ class LLMResultChunkDelta(BaseModel): index: int message: AssistantPromptMessage - usage: Optional[LLMUsage] = None - finish_reason: Optional[str] = None + usage: LLMUsage | None = None + finish_reason: str | None = None class LLMResultChunk(BaseModel): @@ -191,7 +191,7 @@ class LLMResultChunk(BaseModel): model: str prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None delta: LLMResultChunkDelta diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 7cd2e6a3d1..9235c881e0 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,20 +1,20 @@ from abc import ABC from collections.abc import Mapping, Sequence -from enum import Enum, StrEnum -from typing import Annotated, Any, Literal, Optional, Union +from enum import StrEnum, auto +from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, Field, field_serializer, field_validator -class PromptMessageRole(Enum): +class PromptMessageRole(StrEnum): """ Enum class for prompt message. """ - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" + SYSTEM = auto() + USER = auto() + ASSISTANT = auto() + TOOL = auto() @classmethod def value_of(cls, value: str) -> "PromptMessageRole": @@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum): Enum class for prompt message content type. """ - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - DOCUMENT = "document" + TEXT = auto() + IMAGE = auto() + AUDIO = auto() + VIDEO = auto() + DOCUMENT = auto() class PromptMessageContent(ABC, BaseModel): @@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent): """ class DETAIL(StrEnum): - LOW = "low" - HIGH = "high" + LOW = auto() + HIGH = auto() type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -146,8 +146,8 @@ class PromptMessage(ABC, BaseModel): """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContentUnionTypes]] = None - name: Optional[str] = None + content: str | list[PromptMessageContentUnionTypes] | None = None + name: str | None = None def is_empty(self) -> bool: """ @@ -193,8 +193,8 @@ class PromptMessage(ABC, BaseModel): @field_serializer("content") def serialize_content( - self, content: Optional[Union[str, Sequence[PromptMessageContent]]] - ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]: + self, content: Union[str, Sequence[PromptMessageContent]] | None + ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: if content is None or isinstance(content, str): return content if isinstance(content, list): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 568149cc37..aee6ce1108 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,23 +1,23 @@ from decimal import Decimal -from enum import Enum, StrEnum -from typing import Any, Optional +from enum import StrEnum, auto +from typing import Any from pydantic import BaseModel, 
ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject -class ModelType(Enum): +class ModelType(StrEnum): """ Enum class for model type. """ - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - TTS = "tts" + RERANK = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + TTS = auto() @classmethod def value_of(cls, origin_model_type: str) -> "ModelType": @@ -26,17 +26,17 @@ class ModelType(Enum): :return: model type """ - if origin_model_type in {"text-generation", cls.LLM.value}: + if origin_model_type in {"text-generation", cls.LLM}: return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK.value}: + elif origin_model_type in {"reranking", cls.RERANK}: return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS.value}: + elif origin_model_type in {"tts", cls.TTS}: return cls.TTS - elif origin_model_type == cls.MODERATION.value: + elif origin_model_type == cls.MODERATION: return cls.MODERATION else: raise ValueError(f"invalid origin model type {origin_model_type}") @@ -63,7 +63,7 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") -class FetchFrom(Enum): +class FetchFrom(StrEnum): """ Enum class for fetch from. """ @@ -72,7 +72,7 @@ class FetchFrom(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class ModelFeature(Enum): +class ModelFeature(StrEnum): """ Enum class for llm feature. """ @@ -80,11 +80,11 @@ class ModelFeature(Enum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() STRUCTURED_OUTPUT = "structured-output" @@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum): Enum class for parameter template variable. """ - TEMPERATURE = "temperature" - TOP_P = "top_p" - TOP_K = "top_k" - PRESENCE_PENALTY = "presence_penalty" - FREQUENCY_PENALTY = "frequency_penalty" - MAX_TOKENS = "max_tokens" - RESPONSE_FORMAT = "response_format" - JSON_SCHEMA = "json_schema" + TEMPERATURE = auto() + TOP_P = auto() + TOP_K = auto() + PRESENCE_PENALTY = auto() + FREQUENCY_PENALTY = auto() + MAX_TOKENS = auto() + RESPONSE_FORMAT = auto() + JSON_SCHEMA = auto() @classmethod def value_of(cls, value: Any) -> "DefaultParameterName": @@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum): raise ValueError(f"invalid parameter name {value}") -class ParameterType(Enum): +class ParameterType(StrEnum): """ Enum class for parameter type. """ - FLOAT = "float" - INT = "int" - STRING = "string" - BOOLEAN = "boolean" - TEXT = "text" + FLOAT = auto() + INT = auto() + STRING = auto() + BOOLEAN = auto() + TEXT = auto() -class ModelPropertyKey(Enum): +class ModelPropertyKey(StrEnum): """ Enum class for model property key. 
""" - MODE = "mode" - CONTEXT_SIZE = "context_size" - MAX_CHUNKS = "max_chunks" - FILE_UPLOAD_LIMIT = "file_upload_limit" - SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions" - MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk" - DEFAULT_VOICE = "default_voice" - VOICES = "voices" - WORD_LIMIT = "word_limit" - AUDIO_TYPE = "audio_type" - MAX_WORKERS = "max_workers" + MODE = auto() + CONTEXT_SIZE = auto() + MAX_CHUNKS = auto() + FILE_UPLOAD_LIMIT = auto() + SUPPORTED_FILE_EXTENSIONS = auto() + MAX_CHARACTERS_PER_CHUNK = auto() + DEFAULT_VOICE = auto() + VOICES = auto() + WORD_LIMIT = auto() + AUDIO_TYPE = auto() + MAX_WORKERS = auto() class ProviderModel(BaseModel): @@ -154,7 +154,7 @@ class ProviderModel(BaseModel): model: str label: I18nObject model_type: ModelType - features: Optional[list[ModelFeature]] = None + features: list[ModelFeature] | None = None fetch_from: FetchFrom model_properties: dict[ModelPropertyKey, Any] deprecated: bool = False @@ -171,15 +171,15 @@ class ParameterRule(BaseModel): """ name: str - use_template: Optional[str] = None + use_template: str | None = None label: I18nObject type: ParameterType - help: Optional[I18nObject] = None + help: I18nObject | None = None required: bool = False - default: Optional[Any] = None - min: Optional[float] = None - max: Optional[float] = None - precision: Optional[int] = None + default: Any | None = None + min: float | None = None + max: float | None = None + precision: int | None = None options: list[str] = [] @@ -189,7 +189,7 @@ class PriceConfig(BaseModel): """ input: Decimal - output: Optional[Decimal] = None + output: Decimal | None = None unit: Decimal currency: str @@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel): """ parameter_rules: list[ParameterRule] = [] - pricing: Optional[PriceConfig] = None + pricing: PriceConfig | None = None @model_validator(mode="after") def validate_model(self): @@ -220,13 +220,13 @@ class ModelUsage(BaseModel): pass -class PriceType(Enum): +class PriceType(StrEnum): """ Enum class for price type. """ - INPUT = "input" - OUTPUT = "output" + INPUT = auto() + OUTPUT = auto() class PriceInfo(BaseModel): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index c9aa8d1474..2ccc9e0eae 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from enum import Enum -from typing import Optional +from enum import Enum, StrEnum, auto from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -17,16 +16,16 @@ class ConfigurateMethod(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class FormType(Enum): +class FormType(StrEnum): """ Enum class for form type. 
""" TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" - SELECT = "select" - RADIO = "radio" - SWITCH = "switch" + SELECT = auto() + RADIO = auto() + SWITCH = auto() class FormShowOnObject(BaseModel): @@ -62,9 +61,9 @@ class CredentialFormSchema(BaseModel): label: I18nObject type: FormType required: bool = True - default: Optional[str] = None - options: Optional[list[FormOption]] = None - placeholder: Optional[I18nObject] = None + default: str | None = None + options: list[FormOption] | None = None + placeholder: I18nObject | None = None max_length: int = 0 show_on: list[FormShowOnObject] = [] @@ -79,7 +78,7 @@ class ProviderCredentialSchema(BaseModel): class FieldModelSchema(BaseModel): label: I18nObject - placeholder: Optional[I18nObject] = None + placeholder: I18nObject | None = None class ModelCredentialSchema(BaseModel): @@ -98,8 +97,8 @@ class SimpleProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -120,24 +119,24 @@ class ProviderEntity(BaseModel): provider: str label: I18nObject - description: Optional[I18nObject] = None - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None - icon_small_dark: Optional[I18nObject] = None - icon_large_dark: Optional[I18nObject] = None - background: Optional[str] = None - help: Optional[ProviderHelpEntity] = None + description: I18nObject | None = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None + icon_small_dark: I18nObject | None = None + icon_large_dark: I18nObject | None = None + background: str | None = None + help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] configurate_methods: list[ConfigurateMethod] models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: Optional[ProviderCredentialSchema] = None - model_credential_schema: Optional[ModelCredentialSchema] = None + provider_credential_schema: ProviderCredentialSchema | None = None + model_credential_schema: ModelCredentialSchema | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) # position from plugin _position.yaml - position: Optional[dict[str, list[str]]] = {} + position: dict[str, list[str]] | None = {} @field_validator("models", mode="before") @classmethod diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 6bcb707684..80cf01fb6c 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,12 +1,9 @@ -from typing import Optional - - class InvokeError(ValueError): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description def __str__(self): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index a10256c8d8..45f0335c2e 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import hashlib from threading import Lock -from typing import Optional from pydantic import BaseModel, ConfigDict, Field @@ -100,7 +99,7 @@ class 
AIModel(BaseModel): model_schema = self.get_model_schema(model, credentials) # get price info from predefined model schema - price_config: Optional[PriceConfig] = None + price_config: PriceConfig | None = None if model_schema and model_schema.pricing: price_config = model_schema.pricing @@ -133,7 +132,7 @@ class AIModel(BaseModel): currency=price_config.currency, ) - def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: """ Get model schema by model name and credentials @@ -174,7 +173,7 @@ class AIModel(BaseModel): return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials @@ -232,7 +231,7 @@ class AIModel(BaseModel): return schema - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index e363d70cfc..c0f4c504d9 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Sequence -from typing import Optional, Union +from typing import Union from pydantic import ConfigDict @@ -93,12 +93,12 @@ class LargeLanguageModel(AIModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ Invoke large language model @@ -244,11 +244,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunk, None, None]: """ Invoke result generator @@ -329,7 +329,7 @@ class LargeLanguageModel(AIModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> int: """ Get number of tokens for given prompt messages @@ -357,7 +357,7 @@ class LargeLanguageModel(AIModel): ) return 0 - def _calc_response_usage( + def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int ) -> LLMUsage: """ @@ -406,11 +406,11 @@ class LargeLanguageModel(AIModel): 
credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger before invoke callbacks @@ -454,11 +454,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger new chunk callbacks @@ -501,11 +501,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger after invoke callbacks @@ -551,11 +551,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger invoke error callbacks diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index d17fea6321..7aff0184f4 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -1,5 +1,4 @@ import time -from typing import Optional from pydantic import ConfigDict @@ -17,7 +16,7 @@ class ModerationModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: """ Invoke moderation model diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index c1422033f3..36067118b0 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -18,9 +16,9 @@ class RerankModel(AIModel): credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + 
score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index d20b80365a..9d3bf13e79 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -1,4 +1,4 @@ -from typing import IO, Optional +from typing import IO from pydantic import ConfigDict @@ -16,7 +16,7 @@ class Speech2TextModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: """ Invoke speech to text model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 05c96a3e93..bd68ffe903 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType @@ -23,7 +21,7 @@ class TextEmbeddingModel(AIModel): model: str, credentials: dict, texts: list[str], - user: Optional[str] = None, + user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: """ @@ -48,7 +46,7 @@ class TextEmbeddingModel(AIModel): model=model, credentials=credentials, texts=texts, - input_type=input_type.value, + input_type=input_type, ) except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index 8f8a638af6..23d36c03af 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -1,10 +1,10 @@ import logging from threading import Lock -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) -_tokenizer: Optional[Any] = None +_tokenizer: Any | None = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 8529463bc7..a83c8be37c 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,6 +1,5 @@ import logging from collections.abc import Iterable -from typing import Optional from pydantic import ConfigDict @@ -27,7 +26,7 @@ class TTSModel(AIModel): credentials: dict, content_text: str, voice: str, - user: Optional[str] = None, + user: str | None = None, ) -> Iterable[bytes]: """ Invoke large language model @@ -57,7 +56,7 @@ class TTSModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) - def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None): + def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): """ Retrieves the list of voices supported by a given text-to-speech (TTS) model. 
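The `Optional[T]` → `T | None` rewrites in these base-model signatures are purely syntactic on Python 3.10+: both spellings produce equal union objects at runtime, so a parameter such as `user: str | None = None` behaves exactly as before. A minimal standalone sketch (not Dify code) of the equivalence, plus the one caveat that quoted forward references cannot use the `|` operator:

```python
# Standalone sketch; assumes Python 3.10+ (PEP 604 union syntax).
from typing import Optional, Union, get_args

# Both spellings compare equal and carry the same argument tuple.
assert str | None == Optional[str] == Union[str, None]
assert get_args(str | None) == get_args(Optional[str]) == (str, type(None))

# Caveat: a quoted forward reference has no __or__ overload, so the
# expression below raises TypeError at import time; quoted types must
# stay as Optional["SomeType"]. (SomeType is a hypothetical placeholder.)
try:
    _ = "SomeType" | None  # type: ignore[operator]
except TypeError:
    pass
```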
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 8bea9bd121..250b158f99 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,14 +1,9 @@ import hashlib import logging -import os from collections.abc import Sequence from threading import Lock -from typing import Optional - -from pydantic import BaseModel import contexts -from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -26,15 +21,10 @@ from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) -class ModelProviderExtension(BaseModel): - plugin_model_provider_entity: PluginModelProviderEntity - position: Optional[int] = None - - class ModelProviderFactory: provider_position_map: dict[str, int] - def __init__(self, tenant_id: str) -> None: + def __init__(self, tenant_id: str): from core.plugin.impl.model import PluginModelClient self.provider_position_map = {} @@ -42,34 +32,15 @@ class ModelProviderFactory: self.tenant_id = tenant_id self.plugin_model_manager = PluginModelClient() - if not self.provider_position_map: - # get the path of current classes - current_path = os.path.abspath(__file__) - model_providers_path = os.path.dirname(current_path) - - # get _position.yaml file path - self.provider_position_map = get_provider_position_map(model_providers_path) - def get_providers(self) -> Sequence[ProviderEntity]: """ Get all providers :return: list of providers """ - # Fetch plugin model providers + # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server + # The plugin server should return providers in the desired order plugin_providers = self.get_plugin_model_providers() - - # Convert PluginModelProviderEntity to ModelProviderExtension - model_provider_extensions = [] - for provider in plugin_providers: - model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider)) - - sorted_extensions = sort_to_dict_by_position_map( - position_map=self.provider_position_map, - data=model_provider_extensions, - name_func=lambda x: x.plugin_model_provider_entity.declaration.provider, - ) - - return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()] + return [provider.declaration for provider in plugin_providers] def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]: """ @@ -238,9 +209,9 @@ class ModelProviderFactory: def get_models( self, *, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - provider_configs: Optional[list[ProviderConfig]] = None, + provider: str | None = None, + model_type: ModelType | None = None, + provider_configs: list[ProviderConfig] | None = None, ) -> list[SimpleProviderEntity]: """ Get all models for given model type diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 962e417671..c758eaf49f 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -8,7 +8,7 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6 
from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from uuid import UUID from pydantic import BaseModel @@ -18,7 +18,7 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) @@ -98,9 +98,9 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, + custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, sqlalchemy_safe: bool = True, -): +) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: if type(obj) in custom_encoder: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index ce7bd21110..573f4ec2a7 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from sqlalchemy import select @@ -87,7 +85,7 @@ class ApiModeration(Moderation): return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension | None: stmt = select(APIBasedExtension).where( APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 752617b654..d76b4689be 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,15 +1,14 @@ from abc import ABC, abstractmethod -from enum import Enum -from typing import Optional +from enum import StrEnum, auto from pydantic import BaseModel, Field from core.extension.extensible import Extensible, ExtensionModule -class ModerationAction(Enum): - DIRECT_OUTPUT = "direct_output" - OVERRIDDEN = "overridden" +class ModerationAction(StrEnum): + DIRECT_OUTPUT = auto() + OVERRIDDEN = auto() class ModerationInputsResult(BaseModel): @@ -34,7 +33,7 @@ class Moderation(Extensible, ABC): module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None): + def __init__(self, app_id: str, tenant_id: str, config: dict | None = None): super().__init__(tenant_id, config) self.app_id = app_id diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 3ac33966cb..21dc58f16f 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -21,7 +21,7 @@ class InputModeration: inputs: Mapping[str, Any], query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None, + trace_manager: TraceQueueManager | None = None, ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. 
diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 6993ec8b0b..a97e3d4253 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import Any, Optional +from typing import Any from flask import Flask, current_app from pydantic import BaseModel, ConfigDict @@ -27,11 +27,11 @@ class OutputModeration(BaseModel): rule: ModerationRule queue_manager: AppQueueManager - thread: Optional[threading.Thread] = None + thread: threading.Thread | None = None thread_running: bool = True buffer: str = "" is_final_chunk: bool = False - final_output: Optional[str] = None + final_output: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def should_direct_output(self) -> bool: @@ -127,7 +127,7 @@ class OutputModeration(BaseModel): if result.action == ModerationAction.DIRECT_OUTPUT: break - def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: + def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> ModerationOutputsResult | None: try: moderation_factory = ModerationFactory( name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 33b84d8ca5..7e817a6bff 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,7 +1,6 @@ import json import logging from collections.abc import Sequence -from typing import Optional from urllib.parse import urljoin from opentelemetry.trace import Link, Status, StatusCode @@ -120,7 +119,7 @@ class AliyunDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -353,8 +352,8 @@ class AliyunDataTrace(BaseTraceInstance): GEN_AI_FRAMEWORK: "dify", TOOL_NAME: node_execution.title, TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False), - TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), - INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), + TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False), + INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False), OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), }, status=self.get_workflow_node_status(node_execution), diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 881ec2141c..09cb6e3fc1 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,7 +7,6 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime -from typing import Optional import requests from opentelemetry import trace as trace_api @@ -184,7 +183,7 @@ def generate_span_id() -> int: return span_id -def convert_to_trace_id(uuid_v4: Optional[str]) -> int: +def convert_to_trace_id(uuid_v4: str | None) -> int: try: uuid_obj = uuid.UUID(uuid_v4) return uuid_obj.int @@ -192,7 +191,7 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> 
int: raise ValueError(f"Invalid UUID input: {e}") -def convert_string_to_id(string: Optional[str]) -> int: +def convert_string_to_id(string: str | None) -> int: if not string: return generate_span_id() hash_bytes = hashlib.sha256(string.encode("utf-8")).digest() @@ -200,7 +199,7 @@ def convert_string_to_id(string: Optional[str]) -> int: return id -def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: +def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: try: uuid_obj = uuid.UUID(uuid_v4) except Exception as e: @@ -209,7 +208,7 @@ def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: return convert_string_to_id(combined_key) -def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]: +def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None: if start_time_a is None: return None timestamp_in_seconds = start_time_a.timestamp() diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py index 1caa822cd0..f3dcbc5b8f 100644 --- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Optional from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event, Status, StatusCode @@ -10,12 +9,12 @@ class SpanData(BaseModel): model_config = {"arbitrary_types_allowed": True} trace_id: int = Field(..., description="The unique identifier for the trace.") - parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.") + parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.") span_id: int = Field(..., description="The unique identifier for this span.") name: str = Field(..., description="The name of the span.") attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.") events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.") links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.") status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") - start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.") - end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.") + start_time: int | None = Field(..., description="The start time of the span in nanoseconds.") + end_time: int | None = Field(..., description="The end time of the span in nanoseconds.") diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py index 5d70264320..c9427c776a 100644 --- a/api/core/ops/aliyun_trace/entities/semconv.py +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum # public GEN_AI_SESSION_ID = "gen_ai.session.id" @@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description" TOOL_PARAMETERS = "tool.parameters" -class GenAISpanKind(Enum): +class GenAISpanKind(StrEnum): CHAIN = "CHAIN" RETRIEVER = "RETRIEVER" RERANKER = "RERANKER" diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index e7c90c1229..1497bc1863 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ 
b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from urllib.parse import urlparse from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes @@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import RandomIdGenerator from opentelemetry.trace import SpanContext, TraceFlags, TraceState +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig @@ -91,14 +92,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra raise -def datetime_to_nanos(dt: Optional[datetime]) -> int: +def datetime_to_nanos(dt: datetime | None) -> int: """Convert datetime to nanoseconds since epoch. If None, use current time.""" if dt is None: dt = datetime.now() return int(dt.timestamp() * 1_000_000_000) -def string_to_trace_id128(string: Optional[str]) -> int: +def string_to_trace_id128(string: str | None) -> int: """ Convert any input string into a stable 128-bit integer trace ID. @@ -283,7 +284,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): return file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -307,7 +308,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( + workflow_nodes = db.session.execute( + select( WorkflowNodeExecutionModel.id, WorkflowNodeExecutionModel.tenant_id, WorkflowNodeExecutionModel.app_id, @@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): WorkflowNodeExecutionModel.elapsed_time, WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.execution_metadata, - ) - .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) - .all() - ) + ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + ).all() return workflow_nodes def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 71c173d1f1..b8a25c5d7d 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,20 +1,20 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): - message_id: Optional[str] = None - message_data: Optional[Any] = None - inputs: Optional[Union[str, dict[str, Any], list]] =
None - outputs: Optional[Union[str, dict[str, Any], list]] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None + message_id: str | None = None + message_data: Any | None = None + inputs: Union[str, dict[str, Any], list] | None = None + outputs: Union[str, dict[str, Any], list] | None = None + start_time: datetime | None = None + end_time: datetime | None = None metadata: dict[str, Any] - trace_id: Optional[str] = None + trace_id: str | None = None @field_validator("inputs", "outputs") @classmethod @@ -35,9 +35,9 @@ class BaseTraceInfo(BaseModel): class WorkflowTraceInfo(BaseTraceInfo): - workflow_data: Any - conversation_id: Optional[str] = None - workflow_app_log_id: Optional[str] = None + workflow_data: Any = None + conversation_id: str | None = None + workflow_app_log_id: str | None = None workflow_id: str tenant_id: str workflow_run_id: str @@ -46,7 +46,7 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_inputs: Mapping[str, Any] workflow_run_outputs: Mapping[str, Any] workflow_run_version: str - error: Optional[str] = None + error: str | None = None total_tokens: int file_list: list[str] query: str @@ -58,9 +58,9 @@ class MessageTraceInfo(BaseTraceInfo): message_tokens: int answer_tokens: int total_tokens: int - error: Optional[str] = None - file_list: Optional[Union[str, dict[str, Any], list]] = None - message_file_data: Optional[Any] = None + error: str | None = None + file_list: Union[str, dict[str, Any], list] | None = None + message_file_data: Any | None = None conversation_mode: str @@ -73,23 +73,23 @@ class ModerationTraceInfo(BaseTraceInfo): class SuggestedQuestionTraceInfo(BaseTraceInfo): total_tokens: int - status: Optional[str] = None - error: Optional[str] = None - from_account_id: Optional[str] = None - agent_based: Optional[bool] = None - from_source: Optional[str] = None - model_provider: Optional[str] = None - model_id: Optional[str] = None + status: str | None = None + error: str | None = None + from_account_id: str | None = None + agent_based: bool | None = None + from_source: str | None = None + model_provider: str | None = None + model_id: str | None = None suggested_question: list[str] level: str - status_message: Optional[str] = None - workflow_run_id: Optional[str] = None + status_message: str | None = None + workflow_run_id: str | None = None model_config = ConfigDict(protected_namespaces=()) class DatasetRetrievalTraceInfo(BaseTraceInfo): - documents: Any + documents: Any = None class ToolTraceInfo(BaseTraceInfo): @@ -97,23 +97,23 @@ class ToolTraceInfo(BaseTraceInfo): tool_inputs: dict[str, Any] tool_outputs: str metadata: dict[str, Any] - message_file_data: Any - error: Optional[str] = None + message_file_data: Any = None + error: str | None = None tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] + file_url: Union[str, None, list] = None class GenerateNameTraceInfo(BaseTraceInfo): - conversation_id: Optional[str] = None + conversation_id: str | None = None tenant_id: str class TaskData(BaseModel): app_id: str trace_info_type: str - trace_info: Any + trace_info: Any = None trace_info_info_map = { diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index 46ba1c45b9..312c7d3676 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,7 +1,7 @@ from collections.abc import 
Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -52,50 +52,50 @@ class LangfuseTrace(BaseModel): Langfuse trace model """ - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems " "or when creating a distributed trace. Traces are upserted on id.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the trace. Useful for sorting/filtering in the UI.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Union[str, dict[str, Any], list, None] | None = Field( default=None, description="The input of the trace. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Union[str, dict[str, Any], list, None] | None = Field( default=None, description="The output of the trace. Can be any JSON object." ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated " "via the API.", ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, description="The id of the user that triggered the execution. Used to provide user-level analytics.", ) - session_id: Optional[str] = Field( + session_id: str | None = Field( default=None, description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the trace type. Used to understand how changes to the trace type affect metrics. " "Useful in debugging.", ) - release: Optional[str] = Field( + release: str | None = Field( default=None, description="The release identifier of the current deployment. Used to understand how changes of different " "deployments affect metrics. Useful in debugging.", ) - tags: Optional[list[str]] = Field( + tags: list[str] | None = Field( default=None, description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET " "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.", ) - public: Optional[bool] = Field( + public: bool | None = Field( default=None, description="You can make a trace public to share it via a public link. This allows others to view the trace " "without needing to log in or be members of your Langfuse project.", @@ -113,61 +113,61 @@ class LangfuseSpan(BaseModel): Langfuse span model """ - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.", ) - session_id: Optional[str] = Field( + session_id: str | None = Field( default=None, description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.", ) - trace_id: Optional[str] = Field( + trace_id: str | None = Field( default=None, description="The id of the trace the span belongs to. Used to link spans to traces.", ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, description="The id of the user that triggered the execution. 
Used to provide user-level analytics.", ) - start_time: Optional[datetime | str] = Field( + start_time: datetime | str | None = Field( default_factory=datetime.now, description="The time at which the span started, defaults to the current time.", ) - end_time: Optional[datetime | str] = Field( + end_time: datetime | str | None = Field( default=None, description="The time at which the span ended. Automatically set by span.end().", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the span. Useful for sorting/filtering in the UI.", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated " "via the API.", ) - level: Optional[str] = Field( + level: str | None = Field( default=None, description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of " "traces with elevated error levels and for highlighting in the UI.", ) - status_message: Optional[str] = Field( + status_message: str | None = Field( default=None, description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( + input: Union[str, Mapping[str, Any], list, None] | None = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( + output: Union[str, Mapping[str, Any], list, None] | None = Field( default=None, description="The output of the span. Can be any JSON object." ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the span type. Used to understand how changes to the span type affect metrics. " "Useful in debugging.", ) - parent_observation_id: Optional[str] = Field( + parent_observation_id: str | None = Field( default=None, description="The id of the observation the span belongs to. Used to link spans to observations.", ) @@ -188,15 +188,15 @@ class UnitEnum(StrEnum): class GenerationUsage(BaseModel): - promptTokens: Optional[int] = None - completionTokens: Optional[int] = None - total: Optional[int] = None - input: Optional[int] = None - output: Optional[int] = None - unit: Optional[UnitEnum] = None - inputCost: Optional[float] = None - outputCost: Optional[float] = None - totalCost: Optional[float] = None + promptTokens: int | None = None + completionTokens: int | None = None + total: int | None = None + input: int | None = None + output: int | None = None + unit: UnitEnum | None = None + inputCost: float | None = None + outputCost: float | None = None + totalCost: float | None = None @field_validator("input", "output") @classmethod @@ -206,69 +206,69 @@ class GenerationUsage(BaseModel): class LangfuseGeneration(BaseModel): - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the generation can be set, defaults to random id.", ) - trace_id: Optional[str] = Field( + trace_id: str | None = Field( default=None, description="The id of the trace the generation belongs to. Used to link generations to traces.", ) - parent_observation_id: Optional[str] = Field( + parent_observation_id: str | None = Field( default=None, description="The id of the observation the generation belongs to. 
Used to link generations to observations.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the generation. Useful for sorting/filtering in the UI.", ) - start_time: Optional[datetime | str] = Field( + start_time: datetime | str | None = Field( default_factory=datetime.now, description="The time at which the generation started, defaults to the current time.", ) - completion_start_time: Optional[datetime | str] = Field( + completion_start_time: datetime | str | None = Field( default=None, description="The time at which the completion started (streaming). Set it to get latency analytics broken " "down into time until completion started and completion duration.", ) - end_time: Optional[datetime | str] = Field( + end_time: datetime | str | None = Field( default=None, description="The time at which the generation ended. Automatically set by generation.end().", ) - model: Optional[str] = Field(default=None, description="The name of the model used for the generation.") - model_parameters: Optional[dict[str, Any]] = Field( + model: str | None = Field(default=None, description="The name of the model used for the generation.") + model_parameters: dict[str, Any] | None = Field( default=None, description="The parameters of the model used for the generation; can be any key-value pairs.", ) - input: Optional[Any] = Field( + input: Any | None = Field( default=None, description="The prompt used for the generation. Can be any string or JSON object.", ) - output: Optional[Any] = Field( + output: Any | None = Field( default=None, description="The completion generated by the model. Can be any string or JSON object.", ) - usage: Optional[GenerationUsage] = Field( + usage: GenerationUsage | None = Field( default=None, description="The usage object supports the OpenAi structure with tokens and a more generic version with " "detailed costs and units.", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being " "updated via the API.", ) - level: Optional[LevelEnum] = Field( + level: LevelEnum | None = Field( default=None, description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering " "of traces with elevated error levels and for highlighting in the UI.", ) - status_message: Optional[str] = Field( + status_message: str | None = Field( default=None, description="The status message of the generation. Additional field for context of the event. E.g. the error " "message of an error event.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the generation type. Used to understand how changes to the span type affect " "metrics. 
Useful in debugging.", diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 61b6a9c3e6..931bed78d4 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,7 +1,6 @@ import logging import os from datetime import datetime, timedelta -from typing import Optional from langfuse import Langfuse # type: ignore from sqlalchemy.orm import sessionmaker @@ -145,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { @@ -164,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance): "status": status, } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} model_provider = process_data.get("model_provider", None) model_name = process_data.get("model_name", None) if model_provider is not None and model_name is not None: @@ -242,7 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -399,7 +398,7 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_span(langfuse_span_data=name_generation_span_data) - def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): + def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None): format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} try: self.langfuse_client.trace(**format_trace_data) @@ -407,7 +406,7 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create trace: {str(e)}") - def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None): + def add_span(self, langfuse_span_data: LangfuseSpan | None = None): format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} try: self.langfuse_client.span(**format_span_data) @@ -415,12 +414,12 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create span: {str(e)}") - def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None): + def update_span(self, span, langfuse_span_data: LangfuseSpan | None = None): format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} span.end(**format_span_data) - def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None): + def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None): format_generation_data = ( filter_none_values(langfuse_generation_data.model_dump()) if 
langfuse_generation_data else {} ) @@ -430,7 +429,7 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create generation: {str(e)}") - def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None): + def update_generation(self, generation, langfuse_generation_data: LangfuseGeneration | None = None): format_generation_data = ( filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 4fd01136ba..f73ba01c8b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -20,36 +20,36 @@ class LangSmithRunType(StrEnum): class LangSmithTokenUsage(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - total_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None class LangSmithMultiModel(BaseModel): - file_list: Optional[list[str]] = Field(None, description="List of files") + file_list: list[str] | None = Field(None, description="List of files") class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): - name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") + name: str | None = Field(..., description="Name of the run") + inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the run") + outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") - start_time: Optional[datetime | str] = Field(None, description="Start time of the run") - end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") - error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") - parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") - tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") - id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field(None, description="Session ID associated with the run") - session_name: Optional[str] = Field(None, description="Session name associated with the run") - reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the 
run") - input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") - output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") + start_time: datetime | str | None = Field(None, description="Start time of the run") + end_time: datetime | str | None = Field(None, description="End time of the run") + extra: dict[str, Any] | None = Field(None, description="Extra information of the run") + error: str | None = Field(None, description="Error message of the run") + serialized: dict[str, Any] | None = Field(None, description="Serialized data of the run") + parent_run_id: str | None = Field(None, description="Parent run ID") + events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run") + tags: list[str] | None = Field(None, description="Tags associated with the run") + trace_id: str | None = Field(None, description="Trace ID associated with the run") + dotted_order: str | None = Field(None, description="Dotted order of the run") + id: str | None = Field(None, description="ID of the run") + session_id: str | None = Field(None, description="Session ID associated with the run") + session_name: str | None = Field(None, description="Session name associated with the run") + reference_example_id: str | None = Field(None, description="Reference example ID associated with the run") + input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run") + output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") @classmethod @@ -128,15 +128,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") - parent_run_id: Optional[str] = Field(None, description="Parent run ID") - end_time: Optional[datetime | str] = Field(None, description="End time of the run") - error: Optional[str] = Field(None, description="Error message of the run") - inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") - outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") - tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") - input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") - output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") + trace_id: str | None = Field(None, description="Trace ID associated with the run") + dotted_order: str | None = Field(None, description="Dotted order of the run") + parent_run_id: str | None = Field(None, description="Parent run ID") + end_time: datetime | str | None = Field(None, description="End time of the run") + error: str | None = Field(None, description="Error message of the run") + inputs: dict[str, Any] | None = Field(None, description="Inputs of the run") + outputs: dict[str, Any] | None = Field(None, description="Outputs of the run") + events: list[dict[str, Any]] | None = Field(None, description="Events associated with 
the run") + tags: list[str] | None = Field(None, description="Tags associated with the run") + extra: dict[str, Any] | None = Field(None, description="Extra information of the run") + input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run") + output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 1d2155e584..24a43e1cd8 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import cast from langsmith import Client from langsmith.schemas import RunBase @@ -166,13 +166,13 @@ class LangSmithDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( @@ -187,7 +187,7 @@ class LangSmithDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm @@ -246,7 +246,7 @@ class LangSmithDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata @@ -259,7 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index dfb7a1f2e4..8fa92f9fcd 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 @@ -46,7 +46,7 @@ def wrap_metadata(metadata, **kwargs): return metadata -def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]): +def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None): """Opik 
needs UUIDv7 while Dify uses UUIDv4 for identifier of most messages and objects. The type-hints of BaseTraceInfo indicates that objects start_time and message_id could be null which means we cannot map @@ -181,13 +181,13 @@ class OpikDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { @@ -201,7 +201,7 @@ class OpikDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} provider = None model = None @@ -263,7 +263,7 @@ class OpikDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -281,7 +281,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 4805faa5ab..66af061da3 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -1,3 +1,4 @@ +import collections import json import logging import os @@ -42,7 +43,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): +class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): def __getitem__(self, provider: str) -> dict[str, Any]: match provider: case TracingProviderEnum.LANGFUSE: @@ -123,7 +124,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): raise KeyError(f"Unsupported tracing provider: {provider}") -provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap() +provider_config_map = OpsTraceProviderConfigMap() class OpsTraceManager: @@ -220,7 +221,7 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -244,7 +245,7 @@ class OpsTraceManager: @classmethod def get_ops_trace_instance( cls, - app_id: Optional[Union[UUID, str]] = None, + app_id: Union[UUID, str] | None = None, ): """ Get ops trace through model config @@ -257,7 +258,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = 
db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: return None @@ -331,7 +332,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.query(App).where(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -349,7 +350,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -825,7 +826,7 @@ class TraceTask: return generate_name_trace_info -trace_manager_timer: Optional[threading.Timer] = None +trace_manager_timer: threading.Timer | None = None trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 2c0afb1600..5e8651d6f9 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from datetime import datetime -from typing import Optional, Union +from typing import Union from urllib.parse import urlparse from sqlalchemy import select @@ -49,9 +49,7 @@ def replace_text_with_content(data): return data -def generate_dotted_order( - run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None -) -> str: +def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str: """ generate dotted_order for langsmith """ diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index 7f489f37ac..ef1a3be45b 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -8,24 +8,24 @@ from core.ops.utils import replace_text_with_content class WeaveTokenUsage(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - total_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None class WeaveMultiModel(BaseModel): - file_list: Optional[list[str]] = Field(None, description="List of files") + file_list: list[str] | None = Field(None, description="List of files") class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") - attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( + inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace") + outputs: Union[str, 
Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace") + attributes: Union[str, dict[str, Any], list, None] | None = Field( None, description="Metadata and attributes associated with trace" ) - exception: Optional[str] = Field(None, description="Exception message of the trace") + exception: str | None = Field(None, description="Exception message of the trace") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index b103574f72..c6e69191de 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Any, Optional, cast +from typing import Any, cast import wandb import weave @@ -168,13 +168,13 @@ class WeaveDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( @@ -189,7 +189,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { @@ -222,7 +222,7 @@ class WeaveDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) attributes = trace_info.metadata @@ -235,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -423,7 +423,7 @@ class WeaveDataTrace(BaseTraceInstance): logger.debug("Weave API check failed: %s", str(e)) raise ValueError(f"Weave API check failed: {str(e)}") - def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): + def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None): call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self.calls[run_data.id] = call if parent_run_id: diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 9d9fb8f72e..8b08b09eb9 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ 
b/api/core/plugin/backwards_invocation/app.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Optional, Union +from typing import Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app = cls._get_app(app_id, tenant_id) """Retrieve app parameters.""" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app.workflow if workflow is None: raise ValueError("unexpected app type") @@ -53,8 +53,8 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app_id: str, user_id: str, tenant_id: str, - conversation_id: Optional[str], - query: Optional[str], + conversation_id: str | None, + query: str | None, stream: bool, inputs: Mapping, files: list[dict], @@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: if not query: raise ValueError("missing query") @@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT.value: + if app.mode == AppMode.ADVANCED_CHAT: workflow = app.workflow if not workflow: raise ValueError("unexpected app type") @@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.AGENT_CHAT.value: + elif app.mode == AppMode.AGENT_CHAT: return AgentChatAppGenerator().generate( app_model=app, user=user, @@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.CHAT.value: + elif app.mode == AppMode.CHAT: return ChatAppGenerator().generate( app_model=app, user=user, diff --git a/api/core/plugin/backwards_invocation/base.py b/api/core/plugin/backwards_invocation/base.py index 2a5f857576..a89b0f95be 100644 --- a/api/core/plugin/backwards_invocation/base.py +++ b/api/core/plugin/backwards_invocation/base.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from pydantic import BaseModel @@ -23,5 +23,5 @@ T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel) class BaseBackwardsInvocationResponse(BaseModel, Generic[T]): - data: Optional[T] = None + data: T | None = None error: str = "" diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 06773504d9..c2d1574e67 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,7 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool diff --git a/api/core/plugin/entities/endpoint.py 
b/api/core/plugin/entities/endpoint.py index d7ba75bb4f..e5bca140f8 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import BaseModel, Field, model_validator @@ -24,7 +23,7 @@ class EndpointProviderDeclaration(BaseModel): """ settings: list[ProviderConfig] = Field(default_factory=list) - endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration]) + endpoints: list[EndpointDeclaration] | None = Field(default_factory=list[EndpointDeclaration]) class EndpointEntity(BasePluginEntity): diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 1c13a621d4..e0762619e6 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field, model_validator from core.model_runtime.entities.provider_entities import ProviderEntity @@ -19,11 +17,11 @@ class MarketplacePluginDeclaration(BaseModel): resource: PluginResourceRequirements = Field( ..., description="Specification of computational resources needed to run the plugin" ) - endpoint: Optional[EndpointProviderDeclaration] = Field( + endpoint: EndpointProviderDeclaration | None = Field( None, description="Configuration for the plugin's API endpoint, if applicable" ) - model: Optional[ProviderEntity] = Field(None, description="Details of the AI model used by the plugin, if any") - tool: Optional[ToolProviderEntity] = Field( + model: ProviderEntity | None = Field(None, description="Details of the AI model used by the plugin, if any") + tool: ToolProviderEntity | None = Field( None, description="Information about the tool functionality provided by the plugin, if any" ) latest_version: str = Field( diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index b46d973e36..68b5c1084a 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,6 +1,6 @@ -import enum import json -from typing import Any, Optional, Union +from enum import StrEnum, auto +from typing import Any, Union from pydantic import BaseModel, Field, field_validator @@ -11,9 +11,7 @@ from core.tools.entities.common_entities import I18nObject class PluginParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - icon: Optional[str] = Field( - default=None, description="The icon of the option, can be a url or a base64 encoded image" - ) + icon: str | None = Field(default=None, description="The icon of the option, can be a url or a base64 encoded image") @field_validator("value", mode="before") @classmethod @@ -24,44 +22,44 @@ class PluginParameterOption(BaseModel): return value -class PluginParameterType(enum.StrEnum): +class PluginParameterType(StrEnum): """ all available parameter types """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = 
CommonParameterType.ANY.value - DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY + DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES # MCP object and array type parameters - ARRAY = CommonParameterType.ARRAY.value - OBJECT = CommonParameterType.OBJECT.value + ARRAY = CommonParameterType.ARRAY + OBJECT = CommonParameterType.OBJECT -class MCPServerParameterType(enum.StrEnum): +class MCPServerParameterType(StrEnum): """ MCP server got complex parameter types """ - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class PluginParameterAutoGenerate(BaseModel): - class Type(enum.StrEnum): - PROMPT_INSTRUCTION = "prompt_instruction" + class Type(StrEnum): + PROMPT_INSTRUCTION = auto() type: Type @@ -73,15 +71,15 @@ class PluginParameterTemplate(BaseModel): class PluginParameter(BaseModel): name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") - placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") + placeholder: I18nObject | None = Field(default=None, description="The placeholder presented to the user") scope: str | None = None - auto_generate: Optional[PluginParameterAutoGenerate] = None - template: Optional[PluginParameterTemplate] = None + auto_generate: PluginParameterAutoGenerate | None = None + template: PluginParameterTemplate | None = None required: bool = False - default: Optional[Union[float, int, str]] = None - min: Optional[Union[float, int]] = None - max: Optional[Union[float, int]] = None - precision: Optional[int] = None + default: Union[float, int, str] | None = None + min: Union[float, int] | None = None + max: Union[float, int] | None = None + precision: int | None = None options: list[PluginParameterOption] = Field(default_factory=list) @field_validator("options", mode="before") @@ -92,7 +90,7 @@ class PluginParameter(BaseModel): return v -def as_normal_type(typ: enum.StrEnum): +def as_normal_type(typ: StrEnum): if typ.value in { PluginParameterType.SECRET_INPUT, PluginParameterType.SELECT, @@ -101,7 +99,7 @@ def as_normal_type(typ: enum.StrEnum): return typ.value -def cast_parameter_value(typ: enum.StrEnum, value: Any, /): +def cast_parameter_value(typ: StrEnum, value: Any, /): try: match typ.value: case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT: @@ -189,7 +187,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") -def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any): +def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any): """ init frontend parameter by rule """ diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 7857ec7376..3063cd39ae 100644 --- 
a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -1,6 +1,6 @@ import datetime -import enum from collections.abc import Mapping +from enum import StrEnum, auto from typing import Any, Optional from packaging.version import InvalidVersion, Version @@ -15,11 +15,11 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -class PluginInstallationSource(enum.StrEnum): - Github = "github" - Marketplace = "marketplace" - Package = "package" - Remote = "remote" +class PluginInstallationSource(StrEnum): + Github = auto() + Marketplace = auto() + Package = auto() + Remote = auto() class PluginResourceRequirements(BaseModel): @@ -27,40 +27,40 @@ class PluginResourceRequirements(BaseModel): class Permission(BaseModel): class Tool(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Model(BaseModel): - enabled: Optional[bool] = Field(default=False) - llm: Optional[bool] = Field(default=False) - text_embedding: Optional[bool] = Field(default=False) - rerank: Optional[bool] = Field(default=False) - tts: Optional[bool] = Field(default=False) - speech2text: Optional[bool] = Field(default=False) - moderation: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) + llm: bool | None = Field(default=False) + text_embedding: bool | None = Field(default=False) + rerank: bool | None = Field(default=False) + tts: bool | None = Field(default=False) + speech2text: bool | None = Field(default=False) + moderation: bool | None = Field(default=False) class Node(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Endpoint(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Storage(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) size: int = Field(ge=1024, le=1073741824, default=1048576) - tool: Optional[Tool] = Field(default=None) - model: Optional[Model] = Field(default=None) - node: Optional[Node] = Field(default=None) - endpoint: Optional[Endpoint] = Field(default=None) - storage: Optional[Storage] = Field(default=None) + tool: Tool | None = Field(default=None) + model: Model | None = Field(default=None) + node: Node | None = Field(default=None) + endpoint: Endpoint | None = Field(default=None) + storage: Storage | None = Field(default=None) - permission: Optional[Permission] = Field(default=None) + permission: Permission | None = Field(default=None) -class PluginCategory(enum.StrEnum): - Tool = "tool" - Model = "model" - Extension = "extension" +class PluginCategory(StrEnum): + Tool = auto() + Model = auto() + Extension = auto() AgentStrategy = "agent-strategy" Datasource = "datasource" @@ -73,12 +73,12 @@ class PluginDeclaration(BaseModel): datasources: Optional[list[str]] = Field(default_factory=list[str]) class Meta(BaseModel): - minimum_dify_version: Optional[str] = Field(default=None) - version: Optional[str] = Field(default=None) + minimum_dify_version: str | None = Field(default=None) + version: str | None = Field(default=None) @field_validator("minimum_dify_version") @classmethod - def validate_minimum_dify_version(cls, v: Optional[str]) -> Optional[str]: + def validate_minimum_dify_version(cls, v: str | None) -> str | None: if v is None: return v try: @@ -88,18 +88,18 @@ class PluginDeclaration(BaseModel): raise ValueError(f"Invalid version 
format: {v}") from e version: str = Field(...) - author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") + author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") description: I18nObject icon: str - icon_dark: Optional[str] = Field(default=None) + icon_dark: str | None = Field(default=None) label: I18nObject category: PluginCategory created_at: datetime.datetime resource: PluginResourceRequirements plugins: Plugins tags: list[str] = Field(default_factory=list) - repo: Optional[str] = Field(default=None) + repo: str | None = Field(default=None) verified: bool = Field(default=False) tool: Optional[ToolProviderEntity] = None model: Optional[ProviderEntity] = None @@ -161,10 +161,10 @@ class PluginEntity(PluginInstallation): class PluginDependency(BaseModel): - class Type(enum.StrEnum): - Github = PluginInstallationSource.Github.value - Marketplace = PluginInstallationSource.Marketplace.value - Package = PluginInstallationSource.Package.value + class Type(StrEnum): + Github = PluginInstallationSource.Github + Marketplace = PluginInstallationSource.Marketplace + Package = PluginInstallationSource.Package class Github(BaseModel): repo: str @@ -188,9 +188,9 @@ class PluginDependency(BaseModel): type: Type value: Github | Marketplace | Package - current_identifier: Optional[str] = None + current_identifier: str | None = None class MissingPluginDependency(BaseModel): plugin_unique_identifier: str - current_identifier: Optional[str] = None + current_identifier: str | None = None diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 2cb96ac7bb..f15acc16f9 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar from pydantic import BaseModel, ConfigDict, Field @@ -25,7 +25,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]): code: int message: str - data: Optional[T] + data: T | None = None class InstallPluginMessage(BaseModel): @@ -183,7 +183,7 @@ class PluginVerification(BaseModel): class PluginDecodeResponse(BaseModel): unique_identifier: str = Field(description="The unique identifier of the plugin.") manifest: PluginDeclaration - verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information") + verification: PluginVerification | None = Field(default=None, description="Basic verification information") class PluginOAuthAuthorizationUrlResponse(BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 3a783dad3e..10f37f75f8 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -35,7 +35,7 @@ class InvokeCredentials(BaseModel): class PluginInvokeContext(BaseModel): - credentials: Optional[InvokeCredentials] = Field( + credentials: InvokeCredentials | None = Field( default_factory=InvokeCredentials, description="Credentials context for the plugin invocation or backward invocation.", ) @@ -50,7 +50,7 @@ class RequestInvokeTool(BaseModel): provider: str tool: str tool_parameters: dict - credential_id: Optional[str] = None + 
credential_id: str | None = None class BaseRequestInvokeModel(BaseModel): @@ -70,9 +70,9 @@ class RequestInvokeLLM(BaseRequestInvokeModel): mode: str completion_params: dict[str, Any] = Field(default_factory=dict) prompt_messages: list[PromptMessage] = Field(default_factory=list) - tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool]) - stop: Optional[list[str]] = Field(default_factory=list[str]) - stream: Optional[bool] = False + tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool]) + stop: list[str] | None = Field(default_factory=list[str]) + stream: bool | None = False model_config = ConfigDict(protected_namespaces=()) @@ -194,10 +194,10 @@ class RequestInvokeApp(BaseModel): app_id: str inputs: dict[str, Any] - query: Optional[str] = None + query: str | None = None response_mode: Literal["blocking", "streaming"] - conversation_id: Optional[str] = None - user: Optional[str] = None + conversation_id: str | None = None + user: str | None = None files: list[dict] = Field(default_factory=list) diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 544d582f03..7e428939bf 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.plugin.entities.plugin_daemon import ( @@ -82,10 +82,10 @@ class PluginAgentClient(BasePluginClient): agent_provider: str, agent_strategy: str, agent_params: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - context: Optional[PluginInvokeContext] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + context: PluginInvokeContext | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. 
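The `data: Optional[T]` → `data: T | None = None` change in `PluginDaemonBasicResponse` above is more than a style fix: in Pydantic v2, a nullable annotation alone still makes the field required, and only the explicit `= None` default lets callers omit it. A minimal sketch of the difference (hypothetical models, not Dify code):

```python
from pydantic import BaseModel, ValidationError

class WithoutDefault(BaseModel):
    data: str | None          # nullable, but still required in Pydantic v2

class WithDefault(BaseModel):
    data: str | None = None   # nullable and optional

try:
    WithoutDefault()          # `data` was never supplied
except ValidationError as e:
    print(e.errors()[0]["type"])  # -> "missing"

print(WithDefault())          # -> data=None
```

The same reasoning presumably motivates the `= None` defaults added to `HuaweiCloudVectorConfig.username`/`password` further down.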
diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 85a72d9f82..153da142f4 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO, Optional +from typing import IO from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -151,9 +151,9 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, ) -> Generator[LLMResultChunk, None, None]: """ @@ -200,7 +200,7 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> int: """ Get number of tokens for llm @@ -325,8 +325,8 @@ class PluginModelClient(BasePluginClient): credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, + score_threshold: float | None = None, + top_n: int | None = None, ) -> RerankResult: """ Invoke rerank @@ -414,7 +414,7 @@ class PluginModelClient(BasePluginClient): provider: str, model: str, credentials: dict, - language: Optional[str] = None, + language: str | None = None, ): """ Get tts model voices diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 14b6e81700..bc4de38099 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -91,9 +91,9 @@ class PluginToolManager(BasePluginClient): credentials: dict[str, Any], credential_type: CredentialType, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters. 
@@ -193,9 +193,9 @@ class PluginToolManager(BasePluginClient): provider: str, credentials: dict[str, Any], tool: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters of the tool diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index ec66ba02ee..e30076f9d3 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -82,7 +82,9 @@ def merge_blob_chunks( message_class = type(resp) merged_message = message_class( type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), + message=ToolInvokeMessage.BlobMessage( + blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written]) + ), meta=resp.meta, ) yield cast(MessageType, merged_message) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 11c6e5c23b..5f2ffefd94 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Optional, cast +from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -41,11 +41,11 @@ class AdvancedPromptTransform(PromptTransform): inputs: Mapping[str, str], query: str, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: prompt_messages = [] @@ -80,13 +80,13 @@ class AdvancedPromptTransform(PromptTransform): self, prompt_template: CompletionModelPromptTemplate, inputs: Mapping[str, str], - query: Optional[str], + query: str | None, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Get completion model prompt messages. @@ -141,13 +141,13 @@ class AdvancedPromptTransform(PromptTransform): self, prompt_template: list[ChatModelMessage], inputs: Mapping[str, str], - query: Optional[str], + query: str | None, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Get chat model prompt messages. 
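On the `chunk_merger` hunk just above: assuming the per-file accumulation buffer (`files[chunk_id].data`) is a mutable `bytearray` — which the added `bytes(...)` wrapper suggests — slicing it produces another `bytearray`, not `bytes`; the wrapper snapshots the written region as the immutable type a `blob: bytes` field expects. Illustrative sketch, not the actual Dify types:

```python
buf = bytearray(b"hello world!")   # stand-in for a reusable receive buffer
bytes_written = 5

chunk = buf[:bytes_written]
print(type(chunk))                 # <class 'bytearray'> -- still mutable

blob = bytes(buf[:bytes_written])
print(type(blob), blob)            # <class 'bytes'> b'hello' -- immutable copy
```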
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 09f017a7db..a96b094e6d 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import cast from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, @@ -23,7 +23,7 @@ class AgentHistoryPromptTransform(PromptTransform): model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage], history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, + memory: TokenBufferMemory | None = None, ): self.model_config = model_config self.prompt_messages = prompt_messages diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index c8e7b414df..7094633093 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -12,7 +12,7 @@ class ChatModelMessage(BaseModel): text: str role: PromptMessageRole - edition_type: Optional[Literal["basic", "jinja2"]] = None + edition_type: Literal["basic", "jinja2"] | None = None class CompletionModelPromptTemplate(BaseModel): @@ -21,7 +21,7 @@ class CompletionModelPromptTemplate(BaseModel): """ text: str - edition_type: Optional[Literal["basic", "jinja2"]] = None + edition_type: Literal["basic", "jinja2"] | None = None class MemoryConfig(BaseModel): @@ -43,8 +43,8 @@ class MemoryConfig(BaseModel): """ enabled: bool - size: Optional[int] = None + size: int | None = None - role_prefix: Optional[RolePrefix] = None + role_prefix: RolePrefix | None = None window: WindowConfig - query_prompt_template: Optional[str] = None + query_prompt_template: str | None = None diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 1f040599be..a6e873d587 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -55,8 +55,8 @@ class PromptTransform: memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None, + human_prefix: str | None = None, + ai_prefix: str | None = None, ) -> str: """Get memory messages.""" kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d75a230d73..d1d518a55d 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,8 +1,8 @@ -import enum import json import os from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from enum import StrEnum, auto +from typing import TYPE_CHECKING, Any, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -25,9 +25,9 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(enum.StrEnum): - COMPLETION = "completion" - CHAT = "chat" +class ModelMode(StrEnum): + COMPLETION = auto() 
+ CHAT = auto() prompt_file_contents: dict[str, Any] = {} @@ -45,11 +45,11 @@ class SimplePromptTransform(PromptTransform): inputs: Mapping[str, str], query: str, files: Sequence["File"], - context: Optional[str], - memory: Optional[TokenBufferMemory], + context: str | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode(model_config.mode) @@ -86,9 +86,9 @@ class SimplePromptTransform(PromptTransform): model_config: ModelConfigWithCredentialsEntity, pre_prompt: str, inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, + query: str | None = None, + context: str | None = None, + histories: str | None = None, ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( @@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} + custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] + special_variable_keys_obj = prompt_template_config["special_variable_keys"] - for v in prompt_template_config["special_variable_keys"]: + # Type check for custom_variable_keys + if not isinstance(custom_variable_keys_obj, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") + custom_variable_keys = cast(list[str], custom_variable_keys_obj) + + # Type check for special_variable_keys + if not isinstance(special_variable_keys_obj, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") + special_variable_keys = cast(list[str], special_variable_keys_obj) + + variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} + + for v in special_variable_keys: # support #context#, #query# and #histories# if v == "#context#": variables["#context#"] = context or "" @@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform): variables["#histories#"] = histories or "" prompt_template = prompt_template_config["prompt_template"] + if not isinstance(prompt_template, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}") + prompt = prompt_template.format(variables) - return prompt, prompt_template_config["prompt_rules"] + prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + + return prompt, prompt_rules def get_prompt_template( self, @@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ): + ) -> dict[str, object]: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) - custom_variable_keys = [] - special_variable_keys = [] + custom_variable_keys: list[str] = [] + special_variable_keys: list[str] = [] prompt = "" for order in prompt_rules["system_prompt_orders"]: @@ -162,12 +182,12 @@ class SimplePromptTransform(PromptTransform): pre_prompt: str, inputs: dict, query: str, - context: 
Optional[str], + context: str | None, files: Sequence["File"], - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] # get prompt @@ -208,12 +228,12 @@ class SimplePromptTransform(PromptTransform): pre_prompt: str, inputs: dict, query: str, - context: Optional[str], + context: str | None, files: Sequence["File"], - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( app_mode=app_mode, @@ -261,7 +281,7 @@ class SimplePromptTransform(PromptTransform): self, prompt: str, files: Sequence["File"], - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index f5a6274e0d..6f642ab5db 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,7 @@ import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -281,7 +281,7 @@ class ProviderManager: model_type_instance=model_type_instance, ) - def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: + def get_default_model(self, tenant_id: str, model_type: ModelType) -> DefaultModelEntity | None: """ Get default model. @@ -1036,8 +1036,8 @@ class ProviderManager: def _to_model_settings( self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + provider_model_settings: list[ProviderModelSetting] | None = None, + load_balancing_model_configs: list[LoadBalancingModelConfig] | None = None, ) -> list[ModelSettings]: """ Convert to model settings. 
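In the `simple_prompt_transform` hunks above, `get_prompt_template` now returns `dict[str, object]`, so `_get_prompt_str_and_rules` narrows each value with `isinstance` before use and then `cast`s to record the element type. The pattern in isolation (hypothetical key and helper, not the real template config):

```python
from typing import cast

def variable_keys(config: dict[str, object]) -> list[str]:
    keys_obj = config["custom_variable_keys"]
    if not isinstance(keys_obj, list):   # runtime narrowing of `object`
        raise TypeError(f"Expected list, got {type(keys_obj)}")
    # isinstance can prove `list`, but not `list[str]`; cast() records the
    # element type for the type checker at zero runtime cost.
    return cast(list[str], keys_obj)

print(variable_keys({"custom_variable_keys": ["#context#", "name"]}))
```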
diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index d17d76333e..696e3e967f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError @@ -18,8 +16,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + reranking_model: dict | None = None, + weights: dict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -29,9 +27,9 @@ class DataPostProcessor: self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -45,9 +43,9 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - ) -> Optional[BaseRerankRunner]: + reranking_model: dict | None = None, + weights: dict | None = None, + ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: runner = RerankRunnerFactory.create_rerank_runner( runner_type=reranking_mode, @@ -74,12 +72,12 @@ class DataPostProcessor: return runner return None - def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: + def _get_reorder_runner(self, reorder_enabled) -> ReorderRunner | None: if reorder_enabled: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 2e5cbde6dd..3d69d86b65 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Optional +from typing import Any import orjson from pydantic import BaseModel @@ -30,7 +30,7 @@ class Jieba(BaseKeyword): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() keyword_number = ( - self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + self.dataset.keyword_number or self._config.max_keywords_per_chunk ) for text in texts: @@ -53,7 +53,7 @@ class Jieba(BaseKeyword): keyword_table = self._get_dataset_keyword_table() keywords_list = kwargs.get("keywords_list") keyword_number = ( - self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + self.dataset.keyword_number or self._config.max_keywords_per_chunk ) for i in range(len(texts)): text = texts[i] @@ -144,7 +144,7 @@ class Jieba(BaseKeyword): storage.delete(file_key) storage.save(file_key, 
dumps_with_sets(keyword_table_dict).encode("utf-8")) - def _get_dataset_keyword_table(self) -> Optional[dict]: + def _get_dataset_keyword_table(self) -> dict | None: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict @@ -240,7 +240,7 @@ class Jieba(BaseKeyword): ) else: keyword_number = ( - self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk + self.dataset.keyword_number or self._config.max_keywords_per_chunk ) keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index a6214d955b..81619570f9 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,5 +1,5 @@ import re -from typing import Optional, cast +from typing import cast class JiebaKeywordTableHandler: @@ -10,7 +10,7 @@ class JiebaKeywordTableHandler: jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore - def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: + def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" import jieba.analyse # type: ignore diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index fefd42f84d..429744c0de 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,6 +1,5 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor -from typing import Optional from flask import Flask, current_app from sqlalchemy import select @@ -39,11 +38,11 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float] = 0.0, - reranking_model: Optional[dict] = None, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, reranking_mode: str = "reranking_model", - weights: Optional[dict] = None, - document_ids_filter: Optional[list[str]] = None, + weights: dict | None = None, + document_ids_filter: list[str] | None = None, ): if not query: return [] @@ -125,8 +124,8 @@ class RetrievalService: cls, dataset_id: str, query: str, - external_retrieval_model: Optional[dict] = None, - metadata_filtering_conditions: Optional[dict] = None, + external_retrieval_model: dict | None = None, + metadata_filtering_conditions: dict | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) @@ -145,7 +144,7 @@ class RetrievalService: return all_documents @classmethod - def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: + def _get_dataset(cls, dataset_id: str) -> Dataset | None: with Session(db.engine) as session: return session.query(Dataset).where(Dataset.id == dataset_id).first() @@ -158,7 +157,7 @@ class RetrievalService: top_k: int, all_documents: list, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: @@ -182,12 +181,12 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], + score_threshold: float | None, + reranking_model: dict | None, all_documents: list, 
retrieval_method: str, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: @@ -235,12 +234,12 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], + score_threshold: float | None, + reranking_model: dict | None, all_documents: list, retrieval_method: str, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index c3a6127e4a..77a0fa6cf2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: Optional[str] = None + namespace_password: str | None = None metrics: str = "cosine" read_timeout: int = 60000 diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index e7128b183e..de1572410c 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any import chromadb from chromadb import QueryResult, Settings @@ -20,8 +20,8 @@ class ChromaConfig(BaseModel): port: int tenant: str database: str - auth_provider: Optional[str] = None - auth_credentials: Optional[str] = None + auth_provider: str | None = None + auth_credentials: str | None = None def to_chroma_params(self): settings = Settings( diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index eb4cbd2324..e55e5f3101 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -84,7 +84,7 @@ class ClickzettaConnectionPool: self._pool_locks: dict[str, threading.Lock] = {} self._max_pool_size = 5 # Maximum connections per configuration self._connection_timeout = 300 # 5 minutes timeout - self._cleanup_thread: Optional[threading.Thread] = None + self._cleanup_thread: threading.Thread | None = None self._shutdown = False self._start_cleanup_thread() @@ -303,8 +303,8 @@ class ClickzettaVector(BaseVector): """ # Class-level write queue and lock for serializing writes - _write_queue: Optional[queue.Queue] = None - _write_thread: Optional[threading.Thread] = None + _write_queue: queue.Queue | None = None + _write_thread: threading.Thread | None = None _write_lock = threading.Lock() _shutdown = False @@ -328,7 +328,7 @@ class ClickzettaVector(BaseVector): def __init__(self, vector_instance: "ClickzettaVector"): self.vector = vector_instance - self.connection: Optional[Connection] = None + self.connection: Connection | None = None def __enter__(self) -> "Connection": self.connection = self.vector._get_connection() @@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector): for doc, embedding in zip(batch_docs, batch_embeddings): # Optimized: minimal checks for common case, fallback for 
edge cases - metadata = doc.metadata if doc.metadata else {} + metadata = doc.metadata or {} if not isinstance(metadata, dict): metadata = {} diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py index 7118029d40..7b00928b7b 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any from flask import current_app @@ -22,8 +22,8 @@ class ElasticSearchJaVector(ElasticSearchVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index df1c747585..2c147fa7ca 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import urlparse import requests @@ -24,18 +24,18 @@ logger = logging.getLogger(__name__) class ElasticSearchConfig(BaseModel): # Regular Elasticsearch config - host: Optional[str] = None - port: Optional[int] = None - username: Optional[str] = None - password: Optional[str] = None + host: str | None = None + port: int | None = None + username: str | None = None + password: str | None = None # Elastic Cloud specific config - cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud - api_key: Optional[str] = None + cloud_url: str | None = None # Cloud URL for Elasticsearch Cloud + api_key: str | None = None # Common config use_cloud: bool = False - ca_certs: Optional[str] = None + ca_certs: str | None = None verify_certs: bool = False request_timeout: int = 100000 retry_on_timeout: bool = True @@ -256,8 +256,8 @@ class ElasticSearchVector(BaseVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 9887e21b7c..8fc94be360 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -1,13 +1,13 @@ -from enum import Enum +from enum import StrEnum, auto -class Field(Enum): +class Field(StrEnum): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" GROUP_KEY = "group_id" - VECTOR = "vector" + VECTOR = auto() # Sparse Vector aims to support full text search - SPARSE_VECTOR = "sparse_vector" + SPARSE_VECTOR = auto() TEXT_KEY = "text" PRIMARY_KEY = "id" DOC_ID = "metadata.doc_id" diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 107ea75e6a..cfee090768 100644 --- 
a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -1,7 +1,7 @@ import json import logging import ssl -from typing import Any, Optional +from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator @@ -28,8 +28,8 @@ def create_ssl_context() -> ssl.SSLContext: class HuaweiCloudVectorConfig(BaseModel): hosts: str - username: str | None - password: str | None + username: str | None = None + password: str | None = None @model_validator(mode="before") @classmethod @@ -157,8 +157,8 @@ class HuaweiCloudVector(BaseVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 5097412c2c..f3ec30d178 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -2,7 +2,7 @@ import copy import json import logging import time -from typing import Any, Optional +from typing import Any from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError @@ -29,10 +29,10 @@ UGC_INDEX_PREFIX = "ugc_index" class LindormVectorStoreConfig(BaseModel): hosts: str - username: Optional[str] = None - password: Optional[str] = None - using_ugc: Optional[bool] = False - request_timeout: Optional[float] = 1.0 # timeout units: s + username: str | None = None + password: str | None = None + using_ugc: bool | None = False + request_timeout: float | None = 1.0 # timeout units: s @model_validator(mode="before") @classmethod @@ -448,13 +448,13 @@ def default_text_search_query( query_text: str, k: int = 4, text_field: str = Field.CONTENT_KEY.value, - must: Optional[list[dict]] = None, - must_not: Optional[list[dict]] = None, - should: Optional[list[dict]] = None, + must: list[dict] | None = None, + must_not: list[dict] | None = None, + should: list[dict] | None = None, minimum_should_match: int = 0, - filters: Optional[list[dict]] = None, - routing: Optional[str] = None, - routing_field: Optional[str] = None, + filters: list[dict] | None = None, + routing: str | None = None, + routing_field: str | None = None, **kwargs, ): query_clause: dict[str, Any] = {} @@ -505,13 +505,13 @@ def default_vector_search_query( query_vector: list[float], k: int = 4, min_score: str = "0.0", - ef_search: Optional[str] = None, # only for hnsw - nprobe: Optional[str] = None, # "2000" - reorder_factor: Optional[str] = None, # "20" - client_refactor: Optional[str] = None, # "true" + ef_search: str | None = None, # only for hnsw + nprobe: str | None = None, # "2000" + reorder_factor: str | None = None, # "20" + client_refactor: str | None = None, # "true" vector_field: str = Field.VECTOR.value, - filters: Optional[list[dict]] = None, - filter_type: Optional[str] = None, + filters: list[dict] | None = None, + filter_type: str | None = None, **kwargs, ): if filters is not None: diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 7da830f643..6fe396dc1e 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ 
b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -1,8 +1,9 @@ import json import logging import uuid +from collections.abc import Callable from functools import wraps -from typing import Any, Optional +from typing import Any, Concatenate, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset logger = logging.getLogger(__name__) -from typing import ParamSpec, TypeVar P = ParamSpec("P") R = TypeVar("R") @@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel): return values -def ensure_client(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.client is None: - self.client = self._get_client(None, False) - return func(self, *args, **kwargs) - - return wrapper - - class MatrixoneVector(BaseVector): """ Matrixone vector storage implementation. @@ -84,7 +74,7 @@ class MatrixoneVector(BaseVector): self.client = self._get_client(len(embeddings[0]), True) return self.add_texts(texts, embeddings) - def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient: + def _get_client(self, dimension: int | None = None, create_table: bool = False) -> MoVectorClient: """ Create a new client for the collection. @@ -113,7 +103,7 @@ class MatrixoneVector(BaseVector): self.client = self._get_client(len(embeddings[0]), True) assert self.client is not None ids = [] - for _, doc in enumerate(documents): + for doc in documents: if doc.metadata is not None: doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) ids.append(doc_id) @@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector): self.client.delete() +T = TypeVar("T", bound=MatrixoneVector) + + +def ensure_client(func: Callable[Concatenate[T, P], R]): + @wraps(func) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs): + if self.client is None: + self.client = self._get_client(None, False) + return func(self, *args, **kwargs) + + return wrapper + + class MatrixoneVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: if dataset.index_struct_dict: diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 2ec48ae365..5f32feb709 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any from packaging import version from pydantic import BaseModel, model_validator @@ -26,13 +26,13 @@ class MilvusConfig(BaseModel): """ uri: str # Milvus server URI - token: Optional[str] = None # Optional token for authentication - user: Optional[str] = None # Username for authentication - password: Optional[str] = None # Password for authentication + token: str | None = None # Optional token for authentication + user: str | None = None # Username for authentication + password: str | None = None # Password for authentication batch_size: int = 100 # Batch size for operations database: str = "default" # Database name enable_hybrid_search: bool = False # Flag to enable hybrid search - analyzer_params: Optional[str] = None # Analyzer params + analyzer_params: str | None = None # Analyzer params @model_validator(mode="before") @classmethod @@ -79,7 +79,7 @@ class MilvusVector(BaseVector): self._load_collection_fields() 
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported - def _load_collection_fields(self, fields: Optional[list[str]] = None): + def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server collection_info = self._client.describe_collection(self._collection_name) @@ -292,7 +292,7 @@ class MilvusVector(BaseVector): ) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): """ Create a new collection in Milvus with the specified schema and index parameters. diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b590a4dfe4..17aac25b87 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -1,7 +1,7 @@ import json import logging import uuid -from enum import Enum +from enum import StrEnum from typing import Any from clickhouse_connect import get_client @@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel): fts_params: str -class SortOrder(Enum): +class SortOrder(StrEnum): ASC = "ASC" DESC = "DESC" diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 3f65a4a275..3eb1df027e 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Literal, Optional +from typing import Any, Literal from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -26,10 +26,10 @@ class OpenSearchConfig(BaseModel): secure: bool = False # use_ssl verify_certs: bool = True auth_method: Literal["basic", "aws_managed_iam"] = "basic" - user: Optional[str] = None - password: Optional[str] = None - aws_region: Optional[str] = None - aws_service: Optional[str] = None + user: str | None = None + password: str | None = None + aws_region: str | None = None + aws_service: str | None = None @model_validator(mode="before") @classmethod @@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector): }, } # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 - if self._client_config.aws_service not in ["aoss"]: + if self._client_config.aws_service != "aoss": action["_id"] = uuid4().hex actions.append(action) @@ -236,7 +236,7 @@ class OpenSearchVector(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 12d97c500f..d46f29bd64 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union 
import qdrant_client from flask import current_app @@ -40,17 +40,30 @@ if TYPE_CHECKING: MetadataFilter = Union[DictFilter, common_types.Filter] +class PathQdrantParams(BaseModel): + path: str + + +class UrlQdrantParams(BaseModel): + url: str + api_key: str | None + timeout: float + verify: bool + grpc_port: int + prefer_grpc: bool + + class QdrantConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 20 - root_path: Optional[str] = None + root_path: str | None = None grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 write_consistency_factor: int = 1 - def to_qdrant_params(self): + def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams: if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): @@ -58,23 +71,23 @@ class QdrantConfig(BaseModel): raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) - return {"path": path} + return PathQdrantParams(path=path) else: - return { - "url": self.endpoint, - "api_key": self.api_key, - "timeout": self.timeout, - "verify": self.endpoint.startswith("https"), - "grpc_port": self.grpc_port, - "prefer_grpc": self.prefer_grpc, - } + return UrlQdrantParams( + url=self.endpoint, + api_key=self.api_key, + timeout=self.timeout, + verify=self.endpoint.startswith("https"), + grpc_port=self.grpc_port, + prefer_grpc=self.prefer_grpc, + ) class QdrantVector(BaseVector): def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config - self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump()) self._distance_func = distance_func.upper() self._group_id = group_id @@ -176,10 +189,10 @@ class QdrantVector(BaseVector): self, texts: Iterable[str], embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, - group_id: Optional[str] = None, + group_id: str | None = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -221,7 +234,7 @@ class QdrantVector(BaseVector): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, group_id: str, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 9d3dc7c622..99698fcdd0 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -1,6 +1,6 @@ import json import uuid -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert @@ -160,7 +160,7 @@ class RelytVector(BaseVector): else: return None - def delete_by_uuids(self, ids: Optional[list[str]] = None): + def delete_by_uuids(self, ids: list[str] | None = None): """Delete by vector IDs. 
Args: @@ -241,7 +241,7 @@ class RelytVector(BaseVector): self, embedding: list[float], k: int = 4, - filter: Optional[dict] = None, + filter: dict | None = None, ) -> list[tuple[Document, float]]: # Add the filter if provided diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 27685b7ddf..e91d9bb0d6 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -2,7 +2,7 @@ import json import logging import math from collections.abc import Iterable -from typing import Any, Optional +from typing import Any import tablestore # type: ignore from pydantic import BaseModel, model_validator @@ -22,11 +22,11 @@ logger = logging.getLogger(__name__) class TableStoreConfig(BaseModel): - access_key_id: Optional[str] = None - access_key_secret: Optional[str] = None - instance_name: Optional[str] = None - endpoint: Optional[str] = None - normalize_full_text_bm25_score: Optional[bool] = False + access_key_id: str | None = None + access_key_secret: str | None = None + instance_name: str | None = None + endpoint: str | None = None + normalize_full_text_bm25_score: bool | None = False @model_validator(mode="before") @classmethod diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 4af34bbb2d..291d047c04 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any from pydantic import BaseModel from tcvdb_text.encoder import BM25Encoder # type: ignore @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class TencentConfig(BaseModel): url: str - api_key: Optional[str] + api_key: str | None = None timeout: float = 30 - username: Optional[str] - database: Optional[str] + username: str | None = None + database: str | None = None index_type: str = "HNSW" metric_type: str = "IP" shard: int = 1 diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 7055581459..f90a311df4 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import qdrant_client import requests @@ -45,9 +45,9 @@ if TYPE_CHECKING: class TidbOnQdrantConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 20 - root_path: Optional[str] = None + root_path: str | None = None grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 @@ -180,10 +180,10 @@ class TidbOnQdrantVector(BaseVector): self, texts: Iterable[str], embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, - group_id: Optional[str] = None, + group_id: str | None = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -225,7 +225,7 @@ class 
TidbOnQdrantVector(BaseVector): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, group_id: str, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 184b5f2142..e1d4422144 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -1,5 +1,6 @@ import time import uuid +from collections.abc import Sequence import requests from requests.auth import HTTPDigestAuth @@ -139,7 +140,7 @@ class TidbService: @staticmethod def batch_update_tidb_serverless_cluster_status( - tidb_serverless_list: list[TidbAuthBinding], + tidb_serverless_list: Sequence[TidbAuthBinding], project_id: str, api_url: str, iam_url: str, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b2cc51d034..dc4f026ff3 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,7 +1,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from sqlalchemy import select @@ -32,7 +32,7 @@ class AbstractVectorFactory(ABC): class Vector: - def __init__(self, dataset: Dataset, attributes: Optional[list] = None): + def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset @@ -180,7 +180,7 @@ class Vector: case _: raise ValueError(f"Vector store {vector_type} is not supported.") - def create(self, texts: Optional[list] = None, **kwargs): + def create(self, texts: list | None = None, **kwargs): if texts: start = time.time() logger.info("start embedding %s texts %s", len(texts), start) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 43dde37c7e..3ec08b93ed 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Any, Optional +from typing import Any import requests import weaviate # type: ignore @@ -19,7 +19,7 @@ from models.dataset import Dataset class WeaviateConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None batch_size: int = 100 @model_validator(mode="before") diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 63c6db8d06..74a2653e9d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from sqlalchemy import func, select @@ -15,7 +15,7 @@ class DatasetDocumentStore: self, dataset: Dataset, user_id: str, - document_id: Optional[str] = None, + document_id: str | None = None, ): self._dataset = dataset self._user_id = user_id @@ -176,7 +176,7 @@ class DatasetDocumentStore: result = self.get_document_segment(doc_id) return result is not None - def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Document | None: document_segment = self.get_document_segment(doc_id) if document_segment is None: 
@@ -217,16 +217,16 @@ class DatasetDocumentStore: document_segment.index_node_hash = doc_hash db.session.commit() - def get_document_hash(self, doc_id: str) -> Optional[str]: + def get_document_hash(self, doc_id: str) -> str | None: """Get the stored hash for a document, if it exists.""" document_segment = self.get_document_segment(doc_id) if document_segment is None: return None - data: Optional[str] = document_segment.index_node_hash + data: str | None = document_segment.index_node_hash return data - def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: + def get_document_segment(self, doc_id: str) -> DocumentSegment | None: stmt = select(DocumentSegment).where( DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id ) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 43be9cde69..5f94129a0c 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Any, Optional, cast +from typing import Any, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: Optional[str] = None): + def __init__(self, model_instance: ModelInstance, user: str | None = None): self._model_instance = model_instance self._user = user diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index 800422d888..8e92191568 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from models.dataset import DocumentSegment @@ -19,5 +17,5 @@ class RetrievalSegments(BaseModel): model_config = {"arbitrary_types_allowed": True} segment: DocumentSegment - child_chunks: Optional[list[RetrievalChildChunk]] = None - score: Optional[float] = None + child_chunks: list[RetrievalChildChunk] | None = None + score: float | None = None diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 00120425c9..aca879df7d 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -1,23 +1,23 @@ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel class RetrievalSourceMetadata(BaseModel): - position: Optional[int] = None - dataset_id: Optional[str] = None - dataset_name: Optional[str] = None - document_id: Optional[str] = None - document_name: Optional[str] = None - data_source_type: Optional[str] = None - segment_id: Optional[str] = None - retriever_from: Optional[str] = None - score: Optional[float] = None - hit_count: Optional[int] = None - word_count: Optional[int] = None - segment_position: Optional[int] = None - index_node_hash: Optional[str] = None - content: Optional[str] = None - page: Optional[int] = None - doc_metadata: Optional[dict[str, Any]] = None - title: Optional[str] = None + position: int | None = None + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + data_source_type: str | None = None + segment_id: str | None = None + retriever_from: str | None = None + score: float | None = None + hit_count: int | None = None + word_count: int | None = None + segment_position: int | None = None + index_node_hash: str | 
None = None + content: str | None = None + page: int | None = None + doc_metadata: dict[str, Any] | None = None + title: str | None = None diff --git a/api/core/rag/entities/context_entities.py b/api/core/rag/entities/context_entities.py index cd18ad081f..a2b03d54ba 100644 --- a/api/core/rag/entities/context_entities.py +++ b/api/core/rag/entities/context_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -9,4 +7,4 @@ class DocumentContext(BaseModel): """ content: str - score: Optional[float] = None + score: float | None = None diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index 1f054bccdb..b07d760cf4 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ SupportedComparisonOperator = Literal[ class Condition(BaseModel): """ - Conditon detail + Condition detail """ name: str @@ -43,5 +43,5 @@ class MetadataCondition(BaseModel): Metadata Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index 60dbc449f7..1f91a3ece1 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -12,7 +12,7 @@ import mimetypes from collections.abc import Generator, Mapping from io import BufferedReader, BytesIO from pathlib import Path, PurePath -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, model_validator @@ -30,17 +30,17 @@ class Blob(BaseModel): """ data: Union[bytes, str, None] = None # Raw data - mimetype: Optional[str] = None # Not to be confused with a file extension + mimetype: str | None = None # Not to be confused with a file extension encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string # Location where the original content was found # Represent location on the local file system # Useful for situations where downstream code assumes it must work with file paths # rather than in-memory content. - path: Optional[PathLike] = None + path: PathLike | None = None model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) @property - def source(self) -> Optional[str]: + def source(self) -> str | None: """The source location of the blob as string if known otherwise none.""" return str(self.path) if self.path else None @@ -91,7 +91,7 @@ class Blob(BaseModel): path: PathLike, *, encoding: str = "utf-8", - mime_type: Optional[str] = None, + mime_type: str | None = None, guess_type: bool = True, ) -> Blob: """Load the blob from a path like object. @@ -120,8 +120,8 @@ class Blob(BaseModel): data: Union[str, bytes], *, encoding: str = "utf-8", - mime_type: Optional[str] = None, - path: Optional[str] = None, + mime_type: str | None = None, + path: str | None = None, ) -> Blob: """Initialize the blob from in-memory data. 
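Note: several config models in this patch (e.g. `HuaweiCloudVectorConfig`, `TencentConfig`, and the `RetrievalSourceMetadata` fields just above) gain an explicit `= None` alongside the `Optional[T]` → `T | None` rewrite. That default is load-bearing: under Pydantic v2, a `T | None` annotation without a default is still a *required* field that merely allows `None`. A minimal sketch of the difference (model and field names are illustrative, assuming pydantic v2):

```python
from pydantic import BaseModel, ValidationError


class DemoConfig(BaseModel):
    host: str
    username: str | None = None  # optional: may be omitted entirely
    password: str | None         # required: must be passed, even if None


DemoConfig(host="127.0.0.1", password=None)  # validates

try:
    DemoConfig(host="127.0.0.1")  # omitting "password" fails validation
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # -> "missing"
```

This is why the conversions above write `x: T | None = None` rather than a bare `T | None` wherever the field was previously optional in practice.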
diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 5b67403902..3bfae9d6bd 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" import csv -from typing import Optional import pandas as pd @@ -21,10 +20,10 @@ class CSVExtractor(BaseExtractor): def __init__( self, file_path: str, - encoding: Optional[str] = None, + encoding: str | None = None, autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + source_column: str | None = None, + csv_args: dict | None = None, ): """Initialize with file path.""" self._file_path = file_path diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py index 19ad300d11..6568f60ea2 100644 --- a/api/core/rag/extractor/entity/datasource_type.py +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class DatasourceType(Enum): +class DatasourceType(StrEnum): FILE = "upload_file" NOTION = "notion_import" WEBSITE = "website_crawl" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 07f0e90de0..0a57c792f1 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, ConfigDict from models.dataset import Document @@ -15,7 +13,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Optional[Document] = None + document: Document | None = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) @@ -44,10 +42,10 @@ class ExtractSetting(BaseModel): """ datasource_type: str - upload_file: Optional[UploadFile] = None - notion_info: Optional[NotionInfo] = None - website_info: Optional[WebsiteInfo] = None - document_model: Optional[str] = None + upload_file: UploadFile | None = None + notion_info: NotionInfo | None = None + website_info: WebsiteInfo | None = None + document_model: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, **data): diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index baa3fdf2eb..ea9c6bd73a 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional, cast +from typing import cast import pandas as pd from openpyxl import load_workbook @@ -18,7 +18,7 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): + def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 3a7ad8f2ce..3dc08e1832 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -1,7 +1,7 @@ import re import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Union from urllib.parse import unquote from configs import dify_config @@ -90,7 +90,7 @@ class ExtractProcessor: @classmethod def extract( - cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: Optional[str] = None + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: @@ -104,7 +104,7 @@ class ExtractProcessor: input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE - extractor: Optional[BaseExtractor] = None + extractor: BaseExtractor | None = None if etl_type == "Unstructured": unstructured_api_url = dify_config.UNSTRUCTURED_API_URL or "" unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or "" diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 17f7d8661f..00004409d6 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -1,17 +1,17 @@ """Document loader helpers.""" import concurrent.futures -from typing import NamedTuple, Optional, cast +from typing import NamedTuple, cast class FileEncoding(NamedTuple): """A file encoding as the NamedTuple.""" - encoding: Optional[str] + encoding: str | None """The encoding of the file.""" confidence: float """The confidence of the encoding.""" - language: Optional[str] + language: str | None """The language of the file.""" diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index 3845392c8d..79d6ae2dac 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -2,7 +2,6 @@ import re from pathlib import Path -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -22,7 +21,7 @@ class MarkdownExtractor(BaseExtractor): file_path: str, remove_hyperlinks: bool = False, remove_images: bool = False, - encoding: Optional[str] = None, + encoding: str | None = None, autodetect_encoding: bool = True, ): """Initialize with file path.""" @@ -45,13 +44,13 @@ class MarkdownExtractor(BaseExtractor): return documents - def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]: + def markdown_to_tups(self, markdown_text: str) -> list[tuple[str | None, str]]: """Convert a markdown file to a dictionary. The keys are the headers and the values are the text under each header. 
""" - markdown_tups: list[tuple[Optional[str], str]] = [] + markdown_tups: list[tuple[str | None, str]] = [] lines = markdown_text.split("\n") current_header = None @@ -94,7 +93,7 @@ class MarkdownExtractor(BaseExtractor): content = re.sub(pattern, r"\1", content) return content - def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: + def parse_tups(self, filepath: str) -> list[tuple[str | None, str]]: """Parse file into tuples.""" content = "" try: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index e0c68128dd..c1563840f0 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -329,7 +329,7 @@ class NotionExtractor(BaseExtractor): result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: Optional[DocumentModel]): + def update_last_edited_time(self, document_model: DocumentModel | None): if not document_model: return diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 3c43f34104..80530d99a6 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -2,7 +2,6 @@ import contextlib from collections.abc import Iterator -from typing import Optional from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -18,7 +17,7 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, file_cache_key: Optional[str] = None): + def __init__(self, file_path: str, file_cache_key: str | None = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index a00d328cb1..93f301ceff 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" from pathlib import Path -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -16,7 +15,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): + def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 4ed8dfbbd8..5199208f70 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -23,7 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor): unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: - import magic # noqa: F401 # pyright: ignore[reportUnusedImport] + import magic # noqa: F401 is_doc = detect_filetype(self._file_path) == FileType.DOC except ImportError: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2427de8292..ad04bd0bd1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,6 @@ import base64 import contextlib import logging -from typing import Optional from bs4 import BeautifulSoup @@ -17,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index fa91f7dd03..fc14ee6275 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import pypandoc # type: ignore @@ -20,7 +19,7 @@ class UnstructuredEpubExtractor(BaseExtractor): def __init__( self, file_path: str, - api_url: Optional[str] = None, + api_url: str | None = None, api_key: str = "", ): """Initialize with file path.""" diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 0a0c8d3a1c..23030d7739 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -16,7 +15,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index d363449c29..f29e639d1b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ 
b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index ecc272a2f0..c12a55ee4b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index e7bf6fd2e6..99e3eec501 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 916cdc3f2b..d75e166f1b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index c59a70ea57..fe983aa86a 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient @@ -9,7 +9,7 @@ class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: Optional[dict | Any] = None): + def crawl_url(self, url, options: dict | Any | None = None): options = options or {} spider_options = { "max_depth": 1, diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index 05fbf9003b..9ad69e7fe3 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -1,15 +1,15 @@ -from enum import Enum, StrEnum +from enum import StrEnum, auto class BuiltInField(StrEnum): - document_name = "document_name" - uploader = "uploader" - upload_date = "upload_date" - last_update_date = "last_update_date" - source = "source" + document_name = auto() + uploader = auto() + upload_date = auto() + last_update_date = auto() + source = auto() -class MetadataDataSource(Enum): +class MetadataDataSource(StrEnum): upload_file = "file_upload" website_crawl = "website" notion_import = "notion" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index dea9b49631..b3fc4ac221 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from configs import dify_config from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -35,7 +35,7 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod @@ -64,7 +64,7 @@ class BaseIndexProcessor(ABC): max_tokens: int, chunk_overlap: int, separator: str, - embedding_model_instance: Optional["ModelInstance"], + embedding_model_instance: ModelInstance | None, ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 89a1fba798..755aa88d08 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -2,7 +2,7 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -89,7 +89,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index f5e30b73c2..e0ccd8b567 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -3,7 +3,7 @@ import json import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from configs import dify_config from core.model_manager import ModelInstance @@ -114,25 +114,37 @@ class ParentChildIndexProcessor(BaseIndexProcessor): ] vector.create(formatted_child_documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False + precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) + if node_ids: - child_node_ids = ( - db.session.query(ChildChunk.index_node_id) - .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) - .where( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.index_node_id.in_(node_ids), - ChildChunk.dataset_id == dataset.id, + # Use precomputed child_node_ids if available (to avoid race conditions) + if precomputed_child_node_ids is not None: + child_node_ids = precomputed_child_node_ids + else: + # Fallback to original query (may fail if segments are already deleted) + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() ) - .all() - ) - child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] - vector.delete_by_ids(child_node_ids) - if delete_child_chunks: + child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]] + + # Delete from vector index + if child_node_ids: + vector.delete_by_ids(child_node_ids) + + # Delete from database + if delete_child_chunks and child_node_ids: db.session.query(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) ).delete(synchronize_session=False) @@ -180,7 +192,7 @@ class 
ParentChildIndexProcessor(BaseIndexProcessor): document_node: Document, rules: Rule, process_rule_mode: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> list[ChildDocument]: if not rules.subchunk_segmentation: raise ValueError("No subchunk segmentation found in rules.") diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index eedac60aa0..2054031643 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -5,7 +5,7 @@ import re import threading import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any import pandas as pd from flask import Flask, current_app @@ -132,7 +132,7 @@ class QAIndexProcessor(BaseIndexProcessor): vector = Vector(dataset) vector.create(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 5ecd2f796b..4bd7b1d62e 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -10,7 +10,7 @@ class ChildDocument(BaseModel): page_content: str - vector: Optional[list[float]] = None + vector: list[float] | None = None """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). @@ -23,16 +23,16 @@ class Document(BaseModel): page_content: str - vector: Optional[list[float]] = None + vector: list[float] | None = None """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). 
""" metadata: dict = Field(default_factory=dict) - provider: Optional[str] = "dify" + provider: str | None = "dify" - children: Optional[list[ChildDocument]] = None + children: list[ChildDocument] | None = None class GeneralStructureChunk(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 818b04b2ff..3561def008 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from core.rag.models.document import Document @@ -10,9 +9,9 @@ class BaseRerankRunner(ABC): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 7a6ebd1f39..e855b0083f 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_manager import ModelInstance from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner @@ -13,9 +11,9 @@ class RerankModelRunner(BaseRerankRunner): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index ab49e43b70..c455db6095 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -1,6 +1,5 @@ import math from collections import Counter -from typing import Optional import numpy as np @@ -22,9 +21,9 @@ class WeightRerankRunner(BaseRerankRunner): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 93bad23f2b..b08f80da49 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -4,7 +4,7 @@ import re import threading from collections import Counter, defaultdict from collections.abc import Generator, Mapping -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import Float, and_, or_, select, text @@ -85,9 +85,9 @@ class DatasetRetrieval: show_retrieve_source: bool, hit_callback: DatasetIndexToolCallbackHandler, message_id: str, - memory: Optional[TokenBufferMemory] = None, - inputs: Optional[Mapping[str, Any]] = None, - ) -> Optional[str]: + memory: TokenBufferMemory | None = None, + inputs: Mapping[str, Any] | None = None, + ) -> str | None: """ Retrieve dataset. 
:param app_id: app_id @@ -290,9 +290,9 @@ class DatasetRetrieval: model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, - metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, - metadata_condition: Optional[MetadataCondition] = None, + message_id: str | None = None, + metadata_filter_document_ids: dict[str, list[str]] | None = None, + metadata_condition: MetadataCondition | None = None, ): tools = [] for dataset in available_datasets: @@ -410,12 +410,12 @@ class DatasetRetrieval: top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict[str, Any]] = None, + reranking_model: dict | None = None, + weights: dict[str, Any] | None = None, reranking_enable: bool = True, - message_id: Optional[str] = None, - metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, - metadata_condition: Optional[MetadataCondition] = None, + message_id: str | None = None, + metadata_filter_document_ids: dict[str, list[str]] | None = None, + metadata_condition: MetadataCondition | None = None, ): if not available_datasets: return [] @@ -505,9 +505,7 @@ class DatasetRetrieval: return all_documents - def _on_retrieval_end( - self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None - ): + def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: @@ -588,8 +586,8 @@ class DatasetRetrieval: query: str, top_k: int, all_documents: list, - document_ids_filter: Optional[list[str]] = None, - metadata_condition: Optional[MetadataCondition] = None, + document_ids_filter: list[str] | None = None, + metadata_condition: MetadataCondition | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -664,7 +662,7 @@ class DatasetRetrieval: hit_callback: DatasetIndexToolCallbackHandler, user_id: str, inputs: dict, - ) -> Optional[list[DatasetRetrieverBaseTool]]: + ) -> list[DatasetRetrieverBaseTool] | None: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -853,9 +851,9 @@ class DatasetRetrieval: user_id: str, metadata_filtering_mode: str, metadata_model_config: ModelConfig, - metadata_filtering_conditions: Optional[MetadataFilteringCondition], + metadata_filtering_conditions: MetadataFilteringCondition | None, inputs: dict, - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", @@ -950,7 +948,7 @@ class DatasetRetrieval: def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: # get all metadata field metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) metadata_fields = db.session.scalars(metadata_stmt).all() @@ -1005,7 +1003,7 @@ class DatasetRetrieval: return automatic_metadata_filters def _process_metadata_filter_func( - self, sequence: int, condition: str, 
metadata_name: str, value: Optional[Any], filters: list + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index d654463be9..8356861242 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer @@ -24,7 +24,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @classmethod def from_encoder( cls: type[TS], - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 **kwargs: Any, @@ -48,7 +48,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): - def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): + def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index c5b6ac4608..41e6d771e9 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import ( Any, Literal, - Optional, TypeVar, Union, ) @@ -71,7 +70,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -94,7 +93,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): metadatas.append(doc.metadata or {}) return self.create_documents(texts, metadatas=metadatas) - def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + def _join_docs(self, docs: list[str], separator: str) -> str | None: text = separator.join(docs) text = text.strip() if text == "": @@ -110,9 +109,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): docs = [] current_doc: list[str] = [] total = 0 - index = 0 - for d in splits: - _len = lengths[index] + for d, _len in zip(splits, lengths): if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( @@ -134,7 +131,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) - index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) @@ -197,7 +193,7 @@ class TokenTextSplitter(TextSplitter): def __init__( self, encoding_name: str = "gpt2", - model_name: Optional[str] = None, 
+ model_name: str | None = None, allowed_special: Union[Literal["all"], Set[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -248,7 +244,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): def __init__( self, - separators: Optional[list[str]] = None, + separators: list[str] | None = None, keep_separator: bool = True, **kwargs: Any, ): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d6f40491b6..eda7b54d6a 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -6,7 +6,7 @@ providing improved performance by offloading database operations to background w """ import logging -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -39,8 +39,8 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): _session_factory: sessionmaker _tenant_id: str - _app_id: Optional[str] - _triggered_from: Optional[WorkflowRunTriggeredFrom] + _app_id: str | None + _triggered_from: WorkflowRunTriggeredFrom | None _creator_user_id: str _creator_user_role: CreatorUserRole @@ -48,8 +48,8 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowRunTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, ): """ Initialize the repository with Celery task configuration and context information. diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index b36252dba2..21a0b7eefe 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -44,8 +44,8 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): _session_factory: sessionmaker _tenant_id: str - _app_id: Optional[str] - _triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom] + _app_id: str | None + _triggered_from: WorkflowNodeExecutionTriggeredFrom | None _creator_user_id: str _creator_user_role: CreatorUserRole _execution_cache: dict[str, WorkflowNodeExecution] @@ -55,8 +55,8 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, ): """ Initialize the repository with Celery task configuration and context information. 
@@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER # In-memory cache for workflow node executions - self._execution_cache: dict[str, WorkflowNodeExecution] = {} + self._execution_cache = {} # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval - self._workflow_execution_mapping: dict[str, list[str]] = {} + self._workflow_execution_mapping = {} logger.info( "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", @@ -151,7 +151,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 08423effd0..9091a3190b 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -4,7 +4,7 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -41,8 +41,8 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowRunTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, ): """ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. @@ -156,7 +156,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): else None ) db_model.status = domain_model.status - db_model.error = domain_model.error_message if domain_model.error_message else None + db_model.error = domain_model.error_message or None db_model.total_tokens = domain_model.total_tokens db_model.total_steps = domain_model.total_steps db_model.exceptions_count = domain_model.exceptions_count diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 1e36799d3e..219aec5a03 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -64,8 +64,8 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: str, - triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, ): """ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. 
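The cache-initialization hunk above relies on the attribute annotations already declared at class scope, so the assignments in `__init__` no longer repeat them. A small sketch of the pattern, using hypothetical names:

```python
# Annotate once at class scope; assign bare empty containers in __init__.
# Type checkers (basedpyright, mypy) pick up the class-level declarations,
# so repeating dict[str, ...] on the assignment is redundant.
class ExecutionCache:
    _executions: dict[str, str]  # execution_id -> serialized execution
    _by_run: dict[str, list[str]]  # workflow_run_id -> execution ids

    def __init__(self) -> None:
        self._executions = {}
        self._by_run = {}
```

The `domain_model.error_message or None` rewrite in sqlalchemy_workflow_execution_repository.py is the same kind of cleanup: for strings, `x if x else None` and `x or None` are equivalent ways of normalizing falsy values to `None`.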
@@ -470,7 +470,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_db_models_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecutionModel]: """ diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 5a2b803932..6e0462c530 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from models.model import File @@ -46,9 +46,9 @@ class Tool(ABC): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage]: if self.runtime and self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) @@ -96,17 +96,17 @@ class Tool(ABC): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: pass def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters @@ -119,9 +119,9 @@ class Tool(ABC): def get_merged_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get merged runtime parameters @@ -196,7 +196,7 @@ class Tool(ABC): message=ToolInvokeMessage.TextMessage(text=text), ) - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index ddec7b1329..3de0014c61 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from openai import BaseModel from pydantic import Field @@ -13,9 +13,9 @@ class ToolRuntime(BaseModel): """ tenant_id: str - tool_id: Optional[str] = None - invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None + tool_id: str | None = None + invoke_from: InvokeFrom | None = None + tool_invoke_from: ToolInvokeFrom | None = None credentials: dict[str, Any] = Field(default_factory=dict) credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py 
b/api/core/tools/builtin_tool/provider.py index 68bfe5b4a5..45fd16d684 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -18,7 +18,7 @@ from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, ) -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import load_yaml_file_cached class BuiltinToolProviderController(ToolProviderController): @@ -31,7 +31,7 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split(".")[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml") try: - provider_yaml = load_yaml_file(yaml_path, ignore_error=False) + provider_yaml = load_yaml_file_cached(yaml_path) except Exception as e: raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") @@ -71,7 +71,7 @@ class BuiltinToolProviderController(ToolProviderController): for tool_file in tool_files: # get tool name tool_name = tool_file.split(".")[0] - tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) + tool = load_yaml_file_cached(path.join(tool_path, tool_file)) # get tool class, import the module assistant_tool_class: type = load_single_subclass_from_source( diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 5c24920871..af9b5b31c2 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.file.enums import FileType from core.file.file_manager import download @@ -18,9 +18,9 @@ class ASRTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore @@ -56,9 +56,9 @@ class ASRTool(BuiltinTool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: parameters = [] diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f191968812..8bc159bb85 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -16,9 +16,9 @@ class TTSTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: provider, model = 
tool_parameters.get("model").split("#") # type: ignore voice = tool_parameters.get(f"voice#{provider}#{model}") @@ -72,9 +72,9 @@ class TTSTool(BuiltinTool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: parameters = [] diff --git a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py index b4e650e0ed..4383943199 100644 --- a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py +++ b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage from core.tools.builtin_tool.tool import BuiltinTool @@ -12,9 +12,9 @@ class SimpleCode(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke simple code diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index d054afac96..44f94c2723 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any from pytz import timezone as pytz_timezone @@ -13,9 +13,9 @@ class CurrentTimeTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tools diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index a8fd6ec2cd..197b062e44 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class LocaltimeToTimestampTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert localtime to timestamp diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 0ef6331530..462e4be5ce 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ 
b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class TimestampToLocaltimeTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert timestamp to localtime diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index 91316b859a..babfa9bcd9 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class TimezoneConversionTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert time to equivalent time zone diff --git a/api/core/tools/builtin_tool/providers/time/tools/weekday.py b/api/core/tools/builtin_tool/providers/time/tools/weekday.py index 158ce701c0..e26b316bd5 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/weekday.py +++ b/api/core/tools/builtin_tool/providers/time/tools/weekday.py @@ -1,7 +1,7 @@ import calendar from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -12,9 +12,9 @@ class WeekdayTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Calculate the day of the week for a given date diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py index 3bee710879..9d668ac9eb 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -12,9 +12,9 @@ class WebscraperTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tools diff --git a/api/core/tools/custom_tool/provider.py 
b/api/core/tools/custom_tool/provider.py index 5790aea2b0..0cc992155a 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,4 +1,5 @@ from pydantic import Field +from sqlalchemy import select from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_provider import ToolProviderController @@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController): tools: list[ApiTool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider) - .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) - .all() - ) + db_providers = db.session.scalars( + select(ApiToolProvider).where( + ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name + ) + ).all() if db_providers and len(db_providers) != 0: for db_provider in db_providers: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 190af999b1..13dd2114d3 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator from dataclasses import dataclass from os import getenv -from typing import Any, Optional, Union +from typing import Any, Union from urllib.parse import urlencode import httpx @@ -376,9 +376,9 @@ class ApiTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke http request diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ca3be26ff9..ee2b438f5b 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -14,9 +14,9 @@ class ToolApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None + output_schema: dict | None = None ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]] @@ -28,25 +28,25 @@ class ToolProviderApiEntity(BaseModel): name: str # identifier description: I18nObject icon: str | dict - icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool") label: I18nObject # label type: ToolProviderType - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None + masked_credentials: dict | None = None + original_credentials: dict | None = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + plugin_id: str | None = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool") tools: list[ToolApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) # MCP - server_url: Optional[str] = 
Field(default="", description="The server url of the tool") + server_url: str | None = Field(default="", description="The server url of the tool") updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool") - timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool") - sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool") - masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool") - original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool") + server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool") + timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool") + sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") + masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") + original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") @field_validator("tools", mode="before") @classmethod diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index aadbbeb843..2c6d9c1964 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field @@ -9,9 +7,9 @@ class I18nObject(BaseModel): """ en_US: str - zh_Hans: Optional[str] = Field(default=None) - pt_BR: Optional[str] = Field(default=None) - ja_JP: Optional[str] = Field(default=None) + zh_Hans: str | None = Field(default=None) + pt_BR: str | None = Field(default=None) + ja_JP: str | None = Field(default=None) def __init__(self, **data): super().__init__(**data) diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index ffeeabbc1c..eba20b07f0 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from core.tools.entities.tool_entities import ToolParameter @@ -16,14 +14,14 @@ class ApiToolBundle(BaseModel): # method method: str # summary - summary: Optional[str] = None + summary: str | None = None # operation_id - operation_id: Optional[str] = None + operation_id: str | None = None # parameters - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None # author author: str # icon - icon: Optional[str] = None + icon: str | None = None # openapi operation openapi: dict diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 847ff80f45..1dc4b3a9ea 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,9 +1,8 @@ import base64 import contextlib -import enum from collections.abc import Mapping -from enum import Enum -from typing import Any, Optional, Union +from enum import StrEnum, auto +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator @@ -22,7 +21,7 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants 
import TOOL_SELECTOR_MODEL_IDENTITY -class ToolLabelEnum(Enum): +class ToolLabelEnum(StrEnum): SEARCH = "search" IMAGE = "image" VIDEOS = "videos" @@ -42,18 +41,18 @@ class ToolLabelEnum(Enum): OTHER = "other" -class ToolProviderType(enum.StrEnum): +class ToolProviderType(StrEnum): """ Enum class for tool provider """ - PLUGIN = "plugin" + PLUGIN = auto() BUILT_IN = "builtin" - WORKFLOW = "workflow" - API = "api" - APP = "app" + WORKFLOW = auto() + API = auto() + APP = auto() DATASET_RETRIEVAL = "dataset-retrieval" - MCP = "mcp" + MCP = auto() @classmethod def value_of(cls, value: str) -> "ToolProviderType": @@ -69,15 +68,15 @@ class ToolProviderType(enum.StrEnum): raise ValueError(f"invalid mode value {value}") -class ApiProviderSchemaType(Enum): +class ApiProviderSchemaType(StrEnum): """ Enum class for api provider schema type. """ - OPENAPI = "openapi" - SWAGGER = "swagger" - OPENAI_PLUGIN = "openai_plugin" - OPENAI_ACTIONS = "openai_actions" + OPENAPI = auto() + SWAGGER = auto() + OPENAI_PLUGIN = auto() + OPENAI_ACTIONS = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderSchemaType": @@ -93,14 +92,14 @@ class ApiProviderSchemaType(Enum): raise ValueError(f"invalid mode value {value}") -class ApiProviderAuthType(Enum): +class ApiProviderAuthType(StrEnum): """ Enum class for api provider auth type. """ - NONE = "none" - API_KEY_HEADER = "api_key_header" - API_KEY_QUERY = "api_key_query" + NONE = auto() + API_KEY_HEADER = auto() + API_KEY_QUERY = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderAuthType": @@ -177,36 +176,36 @@ class ToolInvokeMessage(BaseModel): return value class LogMessage(BaseModel): - class LogStatus(Enum): - START = "start" - ERROR = "error" - SUCCESS = "success" + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() id: str label: str = Field(..., description="The label of the log") - parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") - error: Optional[str] = Field(default=None, description="The error message") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") status: LogStatus = Field(..., description="The status of the log") data: Mapping[str, Any] = Field(..., description="Detailed log data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log") class RetrieverResourceMessage(BaseModel): retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") - class MessageType(Enum): - TEXT = "text" - IMAGE = "image" - LINK = "link" - BLOB = "blob" - JSON = "json" - IMAGE_LINK = "image_link" - BINARY_LINK = "binary_link" - VARIABLE = "variable" - FILE = "file" - LOG = "log" - BLOB_CHUNK = "blob_chunk" - RETRIEVER_RESOURCES = "retriever_resources" + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() type: MessageType = MessageType.TEXT """ @@ -243,7 +242,7 @@ class ToolInvokeMessage(BaseModel): class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The 
url of the binary") - file_var: Optional[dict[str, Any]] = None + file_var: dict[str, Any] | None = None class ToolParameter(PluginParameter): @@ -251,29 +250,29 @@ class ToolParameter(PluginParameter): Overrides type """ - class ToolParameterType(enum.StrEnum): + class ToolParameterType(StrEnum): """ removes TOOLS_SELECTOR from PluginParameterType """ - STRING = PluginParameterType.STRING.value - NUMBER = PluginParameterType.NUMBER.value - BOOLEAN = PluginParameterType.BOOLEAN.value - SELECT = PluginParameterType.SELECT.value - SECRET_INPUT = PluginParameterType.SECRET_INPUT.value - FILE = PluginParameterType.FILE.value - FILES = PluginParameterType.FILES.value - APP_SELECTOR = PluginParameterType.APP_SELECTOR.value - MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value - ANY = PluginParameterType.ANY.value - DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value + STRING = PluginParameterType.STRING + NUMBER = PluginParameterType.NUMBER + BOOLEAN = PluginParameterType.BOOLEAN + SELECT = PluginParameterType.SELECT + SECRET_INPUT = PluginParameterType.SECRET_INPUT + FILE = PluginParameterType.FILE + FILES = PluginParameterType.FILES + APP_SELECTOR = PluginParameterType.APP_SELECTOR + MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR + ANY = PluginParameterType.ANY + DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT # MCP object and array type parameters - ARRAY = MCPServerParameterType.ARRAY.value - OBJECT = MCPServerParameterType.OBJECT.value + ARRAY = MCPServerParameterType.ARRAY + OBJECT = MCPServerParameterType.OBJECT # deprecated, should not use. - SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -281,17 +280,17 @@ class ToolParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + class ToolParameterForm(StrEnum): + SCHEMA = auto() # should be set while adding tool + FORM = auto() # should be set before invoking tool + LLM = auto() # will be set by LLM type: ToolParameterType = Field(..., description="The type of the parameter") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + human_description: I18nObject | None = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") - llm_description: Optional[str] = None + llm_description: str | None = None # MCP object and array type parameters use this field to store the schema - input_schema: Optional[dict] = None + input_schema: dict | None = None @classmethod def get_simple_instance( @@ -300,7 +299,7 @@ class ToolParameter(PluginParameter): llm_description: str, typ: ToolParameterType, required: bool, - options: Optional[list[str]] = None, + options: list[str] | None = None, ) -> "ToolParameter": """ get a simple tool parameter @@ -341,9 +340,9 @@ class ToolProviderIdentity(BaseModel): name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") - icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | None = Field(default=None, 
description="The dark icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( + tags: list[ToolLabelEnum] | None = Field( default=[], description="The tags of the tool", ) @@ -354,7 +353,7 @@ class ToolIdentity(BaseModel): name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") provider: str = Field(..., description="The provider of the tool") - icon: Optional[str] = None + icon: str | None = None class ToolDescription(BaseModel): @@ -365,8 +364,8 @@ class ToolDescription(BaseModel): class ToolEntity(BaseModel): identity: ToolIdentity parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None - output_schema: Optional[dict] = None + description: ToolDescription | None = None + output_schema: dict | None = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") # pydantic configs @@ -387,9 +386,9 @@ class OAuthSchema(BaseModel): class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity - plugin_id: Optional[str] = None + plugin_id: str | None = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = None + oauth_schema: OAuthSchema | None = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -412,8 +411,8 @@ class ToolInvokeMeta(BaseModel): """ time_cost: float = Field(..., description="The time cost of the tool invoke") - error: Optional[str] = None - tool_config: Optional[dict] = None + error: str | None = None + tool_config: dict | None = None @classmethod def empty(cls) -> "ToolInvokeMeta": @@ -447,14 +446,14 @@ class ToolLabel(BaseModel): icon: str = Field(..., description="The icon of the tool") -class ToolInvokeFrom(Enum): +class ToolInvokeFrom(StrEnum): """ Enum class for tool invoke """ - WORKFLOW = "workflow" - AGENT = "agent" - PLUGIN = "plugin" + WORKFLOW = auto() + AGENT = auto() + PLUGIN = auto() class ToolSelector(BaseModel): @@ -465,11 +464,11 @@ class ToolSelector(BaseModel): type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") required: bool = Field(..., description="Whether the parameter is required") description: str = Field(..., description="The description of the parameter") - default: Optional[Union[int, float, str]] = None - options: Optional[list[PluginParameterOption]] = None + default: Union[int, float, str] | None = None + options: list[PluginParameterOption] | None = None provider_id: str = Field(..., description="The id of the provider") - credential_id: Optional[str] = Field(default=None, description="The id of the credential") + credential_id: str | None = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") @@ -479,9 +478,9 @@ class ToolSelector(BaseModel): return self.model_dump() -class CredentialType(enum.StrEnum): +class CredentialType(StrEnum): API_KEY = "api-key" - OAUTH2 = "oauth2" + OAUTH2 = auto() def get_name(self): if self == CredentialType.API_KEY: diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 5f6eb045ab..60b393e1ea 100644 --- a/api/core/tools/mcp_tool/provider.py +++ 
b/api/core/tools/mcp_tool/provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional, Self +from typing import Any, Self from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController @@ -25,9 +25,9 @@ class MCPToolProviderController(ToolProviderController): provider_id: str, tenant_id: str, server_url: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): super().__init__(entity) self.entity: ToolProviderEntityWithPlugin = entity diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 6810ac683d..976d4dc942 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,7 +1,7 @@ import base64 import json from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient @@ -20,9 +20,9 @@ class MCPTool(Tool): icon: str, server_url: str, provider_id: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): super().__init__(entity, runtime) self.tenant_id = tenant_id @@ -40,9 +40,9 @@ class MCPTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: from core.tools.errors import ToolInvokeError @@ -67,22 +67,42 @@ class MCPTool(Tool): for content in result.content: if isinstance(content, TextContent): - try: - content_json = json.loads(content.text) - if isinstance(content_json, dict): - yield self.create_json_message(content_json) - elif isinstance(content_json, list): - for item in content_json: - yield self.create_json_message(item) - else: - yield self.create_text_message(content.text) - except json.JSONDecodeError: - yield self.create_text_message(content.text) - + yield from self._process_text_content(content) elif isinstance(content, ImageContent): - yield self.create_blob_message( - blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} - ) + yield self._process_image_content(content) + + def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]: + """Process text content and yield appropriate messages.""" + try: + content_json = json.loads(content.text) + yield from self._process_json_content(content_json) + except json.JSONDecodeError: + yield self.create_text_message(content.text) + + def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: + """Process JSON content based on its type.""" + if isinstance(content_json, dict): + yield self.create_json_message(content_json) + elif isinstance(content_json, list): + yield from self._process_json_list(content_json) + else: + # For primitive types (str, int, bool, etc.), convert to string + yield self.create_text_message(str(content_json)) + + def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]: + """Process a list of JSON 
items.""" + if any(not isinstance(item, dict) for item in json_list): + # If the list contains any non-dict item, treat the entire list as a text message. + yield self.create_text_message(str(json_list)) + return + + # Otherwise, process each dictionary as a separate JSON message. + for item in json_list: + yield self.create_json_message(item) + + def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage: + """Process image content and return a blob message.""" + return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}) def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": return MCPTool( diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index e649caec1d..828dc3b810 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.plugin.impl.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -16,7 +16,7 @@ class PluginTool(Tool): self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters: Optional[list[ToolParameter]] = None + self.runtime_parameters: list[ToolParameter] | None = None def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN @@ -25,9 +25,9 @@ class PluginTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() @@ -57,9 +57,9 @@ class PluginTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 5acac20739..cb86555783 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -51,10 +51,10 @@ class ToolEngine: message: Message, invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + trace_manager: TraceQueueManager | None = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> tuple[str, list[str], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. @@ -194,9 +194,9 @@ class ToolEngine: tool: Tool, tool_parameters: dict, user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: """ Invoke the tool with the given arguments. 
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ad650196ce..6289f1d335 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,7 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Optional, Union +from typing import Union from uuid import uuid4 import httpx @@ -72,10 +72,10 @@ class ToolFileManager: *, user_id: str, tenant_id: str, - conversation_id: Optional[str], + conversation_id: str | None, file_binary: bytes, mimetype: str, - filename: Optional[str] = None, + filename: str | None = None, ) -> ToolFile: extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex @@ -112,7 +112,7 @@ class ToolFileManager: user_id: str, tenant_id: str, file_url: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> ToolFile: # try to download image try: @@ -217,7 +217,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optional[Generator], Optional[ToolFile]]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: """ get file binary diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 84b874975a..39646b7fc8 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -87,9 +87,7 @@ class ToolLabelManager: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] - labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() - ) + labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b25d0ef944..011e7ad242 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -161,7 +161,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -450,7 +450,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Tool: """ get tool runtime from plugin @@ -675,9 +675,9 @@ class ToolManager: # get db api providers if "api" in filters: - db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() - ) + db_api_providers = db.session.scalars( + select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) + ).all() api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} @@ -698,9 +698,9 @@ class ToolManager: if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() - ) + workflow_providers = 
db.session.scalars( + select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) + ).all() workflow_provider_controllers: list[WorkflowToolProviderController] = [] for workflow_provider in workflow_providers: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index 2e572099b3..ac2967d0c1 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -13,7 +12,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): description: str = "use this to retrieve a dataset. " tenant_id: str top_k: int = 4 - score_threshold: Optional[float] = None + score_threshold: float | None = None hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] return_resource: bool retriever_from: str diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index b536c5a25c..0e2237befd 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel, Field from sqlalchemy import select @@ -37,7 +37,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - user_id: Optional[str] = None + user_id: str | None = None retrieve_config: DatasetRetrieveConfigEntity inputs: dict diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index d5803e33e7..a62d419243 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -87,9 +87,9 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: return [ ToolParameter( @@ -112,9 +112,9 @@ class DatasetRetrieverTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke dataset retriever tool diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index fd0463d14a..6ea033b2b6 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -1,6 +1,6 @@ import contextlib from copy import deepcopy -from typing import Any, Optional, Protocol +from typing import Any, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter @@ -13,7 +13,7 @@ class ProviderConfigCache(Protocol): 
Interface for provider configuration cache operations """ - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider configuration""" ... diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bf075bd730..0851a54338 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -3,7 +3,6 @@ from collections.abc import Generator from datetime import date, datetime from decimal import Decimal from mimetypes import guess_extension -from typing import Optional from uuid import UUID import numpy as np @@ -60,7 +59,7 @@ class ToolFileMessageTransformer: messages: Generator[ToolInvokeMessage, None, None], user_id: str, tenant_id: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Transform tool message and handle file download @@ -165,5 +164,5 @@ class ToolFileMessageTransformer: yield message @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: + def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 251d914800..526f5c8b9a 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models. """ import json -from typing import Optional, cast +from typing import cast from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult @@ -51,7 +51,7 @@ class ModelInvocationUtils: if not schema: raise InvokeModelError("No model schema found") - max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index cae21633fe..2e306db6c7 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -2,7 +2,6 @@ import re from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError -from typing import Optional from flask import request from requests import get @@ -198,9 +197,9 @@ class ApiBasedToolSchemaParser: return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} - typ: Optional[str] = None + typ: str | None = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py index f3c946b95f..6b7007842d 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -2,7 +2,7 @@ import base64 import hashlib import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from Crypto.Cipher import AES from Crypto.Random import get_random_bytes @@ -28,7 +28,7 @@ class SystemOAuthEncrypter: using AES-CBC mode with a key derived from the application's SECRET_KEY. 
""" - def __init__(self, secret_key: Optional[str] = None): + def __init__(self, secret_key: str | None = None): """ Initialize the OAuth encrypter. @@ -130,7 +130,7 @@ class SystemOAuthEncrypter: # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: +def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: """ Create an OAuth encrypter instance. @@ -144,7 +144,7 @@ def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAu # Global encrypter instance (for backward compatibility) -_oauth_encrypter: Optional[SystemOAuthEncrypter] = None +_oauth_encrypter: SystemOAuthEncrypter | None = None def get_system_oauth_encrypter() -> SystemOAuthEncrypter: diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index d8403c2e15..52c16c34a0 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -2,7 +2,7 @@ import mimetypes import re from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import unquote import chardet @@ -27,7 +27,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str: return text[cursor : cursor + max_length] -def get_url(url: str, user_agent: Optional[str] = None) -> str: +def get_url(url: str, user_agent: str | None = None) -> str: """Fetch URL and return the contents as a string.""" headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 8a0a91a50c..e9b5dab7d3 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,4 +1,5 @@ import logging +from functools import lru_cache from pathlib import Path from typing import Any @@ -8,28 +9,25 @@ from yaml import YAMLError logger = logging.getLogger(__name__) -def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}): - """ - Safe loading a YAML file - :param file_path: the path of the YAML file - :param ignore_error: - if True, return default_value if error occurs and the error will be logged in debug level - if False, raise error if error occurs - :param default_value: the value returned when errors ignored - :return: an object of the YAML content - """ +def _load_yaml_file(*, file_path: str): if not file_path or not Path(file_path).exists(): - if ignore_error: - return default_value - else: - raise FileNotFoundError(f"File not found: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) - return yaml_content or default_value + return yaml_content except Exception as e: - if ignore_error: - return default_value - else: - raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e + + +@lru_cache(maxsize=128) +def load_yaml_file_cached(file_path: str) -> Any: + """ + Cached version of load_yaml_file for static configuration files. 
+ Only use for files that don't change during runtime (e.g., position files). + + :param file_path: the path of the YAML file + :return: an object of the YAML content + """ + return _load_yaml_file(file_path=file_path) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 18e6993b38..4d9c8895fc 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,5 +1,4 @@ from collections.abc import Mapping -from typing import Optional from pydantic import Field @@ -207,7 +206,7 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore + def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore """ get tool by name diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 73163e0e69..5adf04611d 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional +from typing import Any from sqlalchemy import select @@ -61,9 +61,9 @@ class WorkflowTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke the tool diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index b363255b2c..0a41b64228 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - value: list[Segment] + value: list[Segment] = None # type: ignore @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 406b4e6f93..6c9e6d726e 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any + value: Any = None @field_validator("value_type") @classmethod @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str + value: str = None # type: ignore class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float + value: float = None # type: ignore # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. 
# @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int + value: int = None # type: ignore class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] + value: Mapping[str, Any] = None # type: ignore @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File + value: File = None # type: ignore @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool + value: bool = None # type: ignore class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] + value: Sequence[Any] = None # type: ignore class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] + value: Sequence[str] = None # type: ignore @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] + value: Sequence[float | int] = None # type: ignore class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] + value: Sequence[Mapping[str, Any]] = None # type: ignore class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] + value: Sequence[File] = None # type: ignore @property def markdown(self) -> str: @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] + value: Sequence[bool] = None # type: ignore def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/entities/run_condition.py b/api/core/workflow/entities/run_condition.py index eedce8842b..7b9a379215 100644 --- a/api/core/workflow/entities/run_condition.py +++ b/api/core/workflow/entities/run_condition.py @@ -1,5 +1,5 @@ import hashlib -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -10,10 +10,10 @@ class RunCondition(BaseModel): type: Literal["branch_identify", "condition"] """condition type""" - branch_identify: Optional[str] = None + branch_identify: str | None = None """branch identify like: sourceHandle, required when type is branch_identify""" - conditions: Optional[list[Condition]] = None + conditions: list[Condition] | None = None """conditions to run the node, required when type is condition""" @property diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index c41a17e165..a8a86d3db2 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -7,7 +7,7 @@ implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -28,7 +28,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any] = Field(...) inputs: Mapping[str, Any] = Field(...) 
- outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING error_message: str = Field(default="") @@ -37,7 +37,7 @@ class WorkflowExecution(BaseModel): exceptions_count: int = Field(default=0) started_at: datetime = Field(...) - finished_at: Optional[datetime] = None + finished_at: datetime | None = None @property def elapsed_time(self) -> float: diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index 6111d8654b..ef3022352a 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -39,14 +39,14 @@ class WorkflowNodeExecution(BaseModel): # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: Optional[str] = None + node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow index: int # Sequence number for ordering in trace visualization - predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + predecessor_node_id: str | None = None # ID of the node that executed before this one node_id: str # ID of the node being executed node_type: NodeType # Type of node (e.g., start, llm, knowledge) title: str # Display title of the node @@ -59,15 +59,15 @@ class WorkflowNodeExecution(BaseModel): # Execution state status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: Optional[str] = None # Error message if execution failed + error: str | None = None # Error message if execution failed elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds # Additional metadata - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) # Timing information created_at: datetime # When execution started - finished_at: Optional[datetime] = None # When execution completed + finished_at: datetime | None = None # When execution completed _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None) _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None) @@ -123,10 +123,10 @@ class WorkflowNodeExecution(BaseModel): def update_from_mapping( self, - inputs: Optional[Mapping[str, Any]] = None, - process_data: Optional[Mapping[str, Any]] = None, - outputs: Optional[Mapping[str, Any]] = None, - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, + inputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, ): """ Update the model from mappings. 
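Reviewer note (editorial aside, not part of the patch): the Optional[T] → T | None rewrites in the hunks above are pure PEP 604 spelling changes; both forms construct the same runtime type, so Pydantic validation behavior is unchanged. A minimal sketch illustrating the equivalence — ExecutionStub is a hypothetical stand-in, not one of the models in this diff:

from typing import Optional

from pydantic import BaseModel


class ExecutionStub(BaseModel):
    # Hypothetical model mirroring the optional fields rewritten above.
    node_execution_id: str | None = None
    error: str | None = None


# PEP 604 unions compare equal to the typing.Optional spelling (Python >= 3.10),
# so the mechanical rewrite is behavior-preserving.
assert Optional[str] == (str | None)
assert ExecutionStub().error is None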
diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py
index 7e25fc0866..123ef3d449 100644
--- a/api/core/workflow/graph_engine/entities/commands.py
+++ b/api/core/workflow/graph_engine/entities/commands.py
@@ -5,13 +5,13 @@ This module defines command types that can be sent to a running
 GraphEngine instance to control its execution flow.
 """
 
-from enum import Enum
+from enum import StrEnum
 from typing import Any
 
 from pydantic import BaseModel, Field
 
 
-class CommandType(str, Enum):
+class CommandType(StrEnum):
     """Types of commands that can be sent to GraphEngine."""
 
     ABORT = "abort"
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 67f16743c3..4662cec2e9 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -1,6 +1,6 @@
 import json
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from packaging.version import Version
 from pydantic import ValidationError
@@ -71,7 +71,7 @@ class AgentNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = AgentNodeData.model_validate(data)
 
-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
 
     def _get_retry_config(self) -> RetryConfig:
@@ -80,7 +80,7 @@ class AgentNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title
 
-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc
 
     def _get_default_value_dict(self) -> dict[str, Any]:
@@ -324,7 +324,7 @@ class AgentNode(Node):
                 memory = self._fetch_memory(model_instance)
                 if memory:
                     prompt_messages = memory.get_history_prompt_messages(
-                        message_limit=node_data.memory.window.size if node_data.memory.window.size else None
+                        message_limit=node_data.memory.window.size or None
                     )
                     history_prompt_messages = [
                         prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
@@ -408,7 +408,7 @@ class AgentNode(Node):
             icon = None
         return icon
 
-    def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
+    def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
         # get conversation id
         conversation_id_variable = self.graph_runtime_state.variable_pool.get(
             ["sys", SystemVariableKey.CONVERSATION_ID.value]
diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py
index 11b11068e7..ce6eb33ecc 100644
--- a/api/core/workflow/nodes/agent/entities.py
+++ b/api/core/workflow/nodes/agent/entities.py
@@ -1,4 +1,4 @@
-from enum import Enum, StrEnum
+from enum import IntEnum, StrEnum, auto
 from typing import Any, Literal, Union
 
 from pydantic import BaseModel
@@ -25,9 +25,10 @@ class AgentNodeData(BaseNodeData):
     agent_parameters: dict[str, AgentInput]
 
 
-class ParamsAutoGenerated(Enum):
-    CLOSE = 0
-    OPEN = 1
+class ParamsAutoGenerated(IntEnum):
+    # Keep the original 0/1 values: IntEnum's auto() starts at 1 and would silently shift them.
+    CLOSE = 0
+    OPEN = 1
 
 
 class AgentOldVersionModelFeatures(StrEnum):
@@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
     TOOL_CALL = "tool-call"
     MULTI_TOOL_CALL = "multi-tool-call"
     AGENT_THOUGHT = "agent-thought"
-    VISION = "vision"
+    VISION = auto()
     STREAM_TOOL_CALL = "stream-tool-call"
-    DOCUMENT = "document"
-    VIDEO = "video"
-    AUDIO = "audio"
+    DOCUMENT = auto()
+    VIDEO = auto()
+    AUDIO = auto()
diff
--git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index d5955bdd7d..944f5f0b20 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -1,6 +1,3 @@ -from typing import Optional - - class AgentNodeError(Exception): """Base exception for all agent node errors.""" @@ -12,7 +9,7 @@ class AgentNodeError(Exception): class AgentStrategyError(AgentNodeError): """Exception raised when there's an error with the agent strategy.""" - def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None): self.strategy_name = strategy_name self.provider_name = provider_name super().__init__(message) @@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError): class AgentStrategyNotFoundError(AgentStrategyError): """Exception raised when the specified agent strategy is not found.""" - def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + def __init__(self, strategy_name: str, provider_name: str | None = None): super().__init__( f"Agent strategy '{strategy_name}' not found" + (f" for provider '{provider_name}'" if provider_name else ""), @@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError): class AgentInvocationError(AgentNodeError): """Exception raised when there's an error invoking the agent.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError): class AgentParameterError(AgentNodeError): """Exception raised when there's an error with agent parameters.""" - def __init__(self, message: str, parameter_name: Optional[str] = None): + def __init__(self, message: str, parameter_name: str | None = None): self.parameter_name = parameter_name super().__init__(message) @@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError): class AgentVariableError(AgentNodeError): """Exception raised when there's an error with variables in the agent node.""" - def __init__(self, message: str, variable_name: Optional[str] = None): + def __init__(self, message: str, variable_name: str | None = None): self.variable_name = variable_name super().__init__(message) @@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError): class ToolFileError(AgentNodeError): """Exception raised when there's an error with a tool file.""" - def __init__(self, message: str, file_id: Optional[str] = None): + def __init__(self, message: str, file_id: str | None = None): self.file_id = file_id super().__init__(message) @@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError): class AgentMessageTransformError(AgentNodeError): """Exception raised when there's an error transforming agent messages.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError): class AgentModelError(AgentNodeError): """Exception raised when there's an error with the model used by the agent.""" - def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + def __init__(self, message: str, model_name: str | None = None, 
provider: str | None = None): self.model_name = model_name self.provider = provider super().__init__(message) @@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError): class AgentMemoryError(AgentNodeError): """Exception raised when there's an error with the agent's memory.""" - def __init__(self, message: str, conversation_id: Optional[str] = None): + def __init__(self, message: str, conversation_id: str | None = None): self.conversation_id = conversation_id super().__init__(message) @@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError): def __init__( self, message: str, - variable_name: Optional[str] = None, - expected_type: Optional[str] = None, - actual_type: Optional[str] = None, + variable_name: str | None = None, + expected_type: str | None = None, + actual_type: str | None = None, ): self.variable_name = variable_name self.expected_type = expected_type diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 4ef5c880c4..86174c7ea6 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.variables import ArrayFileSegment, FileSegment, Segment from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus @@ -20,7 +20,7 @@ class AnswerNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AnswerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class AnswerNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..850ff14880 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum, auto from pydantic import BaseModel, Field @@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel): Generate Route Chunk. 
""" - class ChunkType(Enum): - VAR = "var" - TEXT = "text" + class ChunkType(StrEnum): + VAR = auto() + TEXT = auto() type: ChunkType = Field(..., description="generate route chunk type") diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index bc07e26456..5aef9d79cf 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -2,7 +2,7 @@ import json from abc import ABC from collections.abc import Sequence from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, model_validator @@ -45,7 +45,7 @@ class DefaultValueType(StrEnum): class DefaultValue(BaseModel): - value: Any + value: Any = None type: DefaultValueType key: str @@ -128,10 +128,10 @@ class DefaultValue(BaseModel): class BaseNodeData(ABC, BaseModel): title: str - desc: Optional[str] = None + desc: str | None = None version: str = "1" - error_strategy: Optional[ErrorStrategy] = None - default_value: Optional[list[DefaultValue]] = None + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None retry_config: RetryConfig = RetryConfig() @property @@ -142,7 +142,7 @@ class BaseNodeData(ABC, BaseModel): class BaseIterationNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseIterationState(BaseModel): @@ -157,7 +157,7 @@ class BaseIterationState(BaseModel): class BaseLoopNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseLoopState(BaseModel): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index ae7f4b19cc..438d768104 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -285,7 +285,7 @@ class Node: ... @abstractmethod - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: """Get the node description.""" ... 
@@ -316,7 +316,7 @@ class Node: return self._get_title() @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """Get the node description.""" return self._get_description() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 8171686022..4fa97e0478 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, Optional +from typing import Any from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -30,7 +30,7 @@ class CodeNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = CodeNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -39,7 +39,7 @@ class CodeNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -49,7 +49,7 @@ class CodeNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. @@ -154,7 +154,7 @@ class CodeNode(Node): def _transform_result( self, result: Mapping[str, Any], - output_schema: Optional[dict[str, CodeNodeData.Output]], + output_schema: dict[str, CodeNodeData.Output] | None, prefix: str = "", depth: int = 1, ): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index c8095e26e1..8026011196 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: Optional[dict[str, "CodeNodeData.Output"]] = None + children: dict[str, "CodeNodeData.Output"] | None = None class Dependency(BaseModel): name: str @@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None + dependencies: list[Dependency] | None = None diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 38213ea4b4..ae1061d72c 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any import chardet import docx @@ -49,7 +49,7 @@ class DocumentExtractorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = DocumentExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return 
self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -58,7 +58,7 @@ class DocumentExtractorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index ca2aeddf3e..2bdfe4efce 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -18,7 +18,7 @@ class EndNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = EndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class EndNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 8d7ba25d47..5a7db6e0e6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,7 +1,7 @@ import mimetypes from collections.abc import Sequence from email.message import Message -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel): class HttpRequestNodeAuthorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[HttpRequestNodeAuthorizationConfig] = None + config: HttpRequestNodeAuthorizationConfig | None = None @field_validator("config", mode="before") @classmethod @@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData): authorization: HttpRequestNodeAuthorization headers: str params: str - body: Optional[HttpRequestNodeBody] = None - timeout: Optional[HttpRequestNodeTimeout] = None - ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + body: HttpRequestNodeBody | None = None + timeout: HttpRequestNodeTimeout | None = None + ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY class Response: @@ -183,7 +183,7 @@ class Response: return f"{(self.size / 1024 / 1024):.2f} MB" @property - def parsed_content_disposition(self) -> Optional[Message]: + def parsed_content_disposition(self) -> Message | None: content_disposition = self.headers.get("content-disposition", "") if content_disposition: msg = Message() diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 8186a002f8..6226d8d362 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from configs import dify_config from core.file import File, 
FileTransferMethod @@ -39,7 +39,7 @@ class HttpRequestNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = HttpRequestNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -48,7 +48,7 @@ class HttpRequestNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -58,7 +58,7 @@ class HttpRequestNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict[str, Any]] = None): + def get_default_config(cls, filters: dict[str, Any] | None = None): return { "type": "http-request", "config": { diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 67d6d6a886..b22bd6f508 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData): logical_operator: Literal["and", "or"] conditions: list[Condition] - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) - cases: Optional[list[Case]] = None + cases: list[Case] | None = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 2149a9a05b..075f6f8444 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import deprecated @@ -22,7 +22,7 @@ class IfElseNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IfElseNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -31,7 +31,7 @@ class IfElseNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 7a489dd725..9608edb06e 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. 
""" - parent_loop_id: Optional[str] = None # redundant field, not used currently + parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector is_parallel: bool = False # open the parallel mode or not @@ -39,7 +39,7 @@ class IterationState(BaseIterationState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any | None = None class MetaData(BaseIterationState.MetaData): """ @@ -48,7 +48,7 @@ class IterationState(BaseIterationState): iterator_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any | None: """ Get last output. """ @@ -56,7 +56,7 @@ class IterationState(BaseIterationState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any | None: """ Get current output. """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index f15730d105..274e829ea5 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,7 +1,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment @@ -58,7 +58,7 @@ class IterationNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -67,7 +67,7 @@ class IterationNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -77,7 +77,7 @@ class IterationNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "type": "iteration", "config": { diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index c03e7257a2..80f39ccebc 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -20,7 +20,7 @@ class IterationStartNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class IterationStartNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: 
return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index b71271abeb..8aa6a5016f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel): """ top_k: int - score_threshold: Optional[float] = None + score_threshold: float | None = None reranking_mode: str = "reranking_model" reranking_enable: bool = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class SingleRetrievalConfig(BaseModel): @@ -91,7 +91,7 @@ SupportedComparisonOperator = Literal[ class Condition(BaseModel): """ - Conditon detail + Condition detail """ name: str @@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel): Metadata Filtering Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) class KnowledgeRetrievalNodeData(BaseNodeData): @@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData): query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] - multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None - single_retrieval_config: Optional[SingleRetrievalConfig] = None - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + multiple_retrieval_config: MultipleRetrievalConfig | None = None + single_retrieval_config: SingleRetrievalConfig | None = None + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" + metadata_model_config: ModelConfig | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d66b0cdf1a..ee5c4ae289 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -119,7 +119,7 @@ class KnowledgeRetrievalNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = KnowledgeRetrievalNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -250,7 +250,7 @@ class 
KnowledgeRetrievalNode(Node): ) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: + if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -282,7 +282,7 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -410,7 +410,7 @@ class KnowledgeRetrievalNode(Node): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index fd6c75944e..7a31d69221 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, Optional, TypeAlias, TypeVar +from typing import Any, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -43,7 +43,7 @@ class ListOperatorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ListOperatorNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -52,7 +52,7 @@ class ListOperatorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -66,8 +66,8 @@ class ListOperatorNode(Node): return "1" def _run(self): - inputs: dict[str, list] = {} - process_data: dict[str, list] = {} + inputs: dict[str, Sequence[object]] = {} + process_data: dict[str, Sequence[object]] = {} outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 72f83eb25b..fe6f2290aa 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -18,7 +18,7 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): enabled: bool - variable_selector: Optional[list[str]] = None + variable_selector: list[str] | None = None class VisionConfigOptions(BaseModel): @@ -51,18 +51,18 @@ class PromptConfig(BaseModel): class 
LLMNodeChatModelMessage(ChatModelMessage): text: str = "" - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: Optional[MemoryConfig] = None + memory: MemoryConfig | None = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: Mapping[str, Any] | None = None diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index af22b8588c..ad969cdad1 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from sqlalchemy import select, update from sqlalchemy.orm import Session @@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc def fetch_memory( - variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance -) -> Optional[TokenBufferMemory]: + variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance +) -> TokenBufferMemory | None: if not node_data_memory: return None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index ec9e42a250..f0f27e1e2d 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -4,7 +4,7 @@ import json import logging import re from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -139,7 +139,7 @@ class LLMNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LLMNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -148,7 +148,7 @@ class LLMNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -354,10 +354,10 @@ class LLMNode(Node): node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, user_id: str, structured_output_enabled: bool, - structured_output: Optional[Mapping[str, Any]] = None, + structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, file_outputs: list["File"], node_id: str, @@ -716,7 +716,7 @@ class LLMNode(Node): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): @@ -959,7 +959,7 @@ class LLMNode(Node): return 
variable_mapping
 
     @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None):
+    def get_default_config(cls, filters: dict | None = None):
         return {
             "type": "llm",
             "config": {
@@ -987,7 +987,7 @@ class LLMNode(Node):
 def handle_list_messages(
     *,
     messages: Sequence[LLMNodeChatModelMessage],
-    context: Optional[str],
+    context: str | None,
     jinja2_variables: Sequence[VariableSelector],
     variable_pool: VariablePool,
     vision_detail_config: ImagePromptMessageContent.DETAIL,
@@ -1175,7 +1175,7 @@ class LLMNode(Node):
 def _combine_message_content_with_role(
-    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
+    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
 ):
     match role:
         case PromptMessageRole.USER:
@@ -1184,7 +1184,8 @@ def _combine_message_content_with_role(
             return AssistantPromptMessage(content=contents)
         case PromptMessageRole.SYSTEM:
             return SystemPromptMessage(content=contents)
-    raise NotImplementedError(f"Role {role} is not supported")
+        case _:
+            raise NotImplementedError(f"Role {role} is not supported")
 
 
 def _render_jinja2_message(
@@ -1280,7 +1281,7 @@ def _handle_memory_completion_mode(
 def _handle_completion_template(
     *,
     template: LLMNodeCompletionModelPromptTemplate,
-    context: Optional[str],
+    context: str | None,
     jinja2_variables: Sequence[VariableSelector],
     variable_pool: VariablePool,
 ) -> Sequence[PromptMessage]:
diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py
index 6f6939810b..57434fa9f0 100644
--- a/api/core/workflow/nodes/loop/entities.py
+++ b/api/core/workflow/nodes/loop/entities.py
@@ -34,7 +34,7 @@ class LoopVariableData(BaseModel):
     label: str
     var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
     value_type: Literal["variable", "constant"]
-    value: Any = None
+    value: Any = None  # constant value or variable selector, depending on value_type
 
 
 class LoopNodeData(BaseLoopNodeData):
@@ -74,7 +74,7 @@ class LoopState(BaseLoopState):
     """
 
     outputs: list[Any] = Field(default_factory=list)
-    current_output: Optional[Any] = None
+    current_output: Any | None = None
 
     class MetaData(BaseLoopState.MetaData):
         """
@@ -83,7 +83,7 @@ class LoopState(BaseLoopState):
 
         loop_length: int
 
-    def get_last_output(self) -> Optional[Any]:
+    def get_last_output(self) -> Any | None:
         """
         Get last output.
         """
@@ -91,7 +91,7 @@ class LoopState(BaseLoopState):
             return self.outputs[-1]
         return None
 
-    def get_current_output(self) -> Optional[Any]:
+    def get_current_output(self) -> Any | None:
         """
         Get current output.
""" diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 8b1b5b424d..38aef06d24 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -20,7 +20,7 @@ class LoopEndNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopEndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class LoopEndNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 3c5259ea26..1b83319ab0 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from configs import dify_config from core.variables import Segment, SegmentType @@ -52,7 +52,7 @@ class LoopNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -61,7 +61,7 @@ class LoopNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 9f3febe9b0..e777a8cbe9 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -20,7 +20,7 @@ class LoopStartNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class LoopStartNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 
4c0b14b2d7..4e3819c4cf 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -48,7 +48,7 @@ class ParameterConfig(BaseModel): name: str type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: Optional[list[str]] = None + options: list[str] | None = None description: str required: bool @@ -86,8 +86,8 @@ class ParameterExtractorNodeData(BaseNodeData): model: ModelConfig query: list[str] parameters: list[ParameterConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None reasoning_mode: Literal["function_call", "prompt"] vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 3f79006836..832422fcc3 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -3,7 +3,7 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File @@ -96,7 +96,7 @@ class ParameterExtractorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ParameterExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -105,7 +105,7 @@ class ParameterExtractorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -114,11 +114,11 @@ class ParameterExtractorNode(Node): def get_base_node_data(self) -> BaseNodeData: return self._node_data - _model_instance: Optional[ModelInstance] = None - _model_config: Optional[ModelConfigWithCredentialsEntity] = None + _model_instance: ModelInstance | None = None + _model_config: ModelConfigWithCredentialsEntity | None = None @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "model": { "prompt_templates": { @@ -293,7 +293,7 @@ class ParameterExtractorNode(Node): prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], stop: list[str], - ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=node_data_model.completion_params, @@ -323,9 +323,9 @@ class ParameterExtractorNode(Node): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> 
tuple[list[PromptMessage], list[PromptMessageTool]]:
         """
         Generate function call prompt.
@@ -405,9 +405,9 @@ class ParameterExtractorNode(Node):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: Optional[TokenBufferMemory],
+        memory: TokenBufferMemory | None,
         files: Sequence[File],
-        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
+        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
         """
         Generate prompt engineering prompt.
@@ -443,9 +443,9 @@ class ParameterExtractorNode(Node):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: Optional[TokenBufferMemory],
+        memory: TokenBufferMemory | None,
         files: Sequence[File],
-        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
+        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
         """
         Generate completion prompt.
@@ -477,9 +477,9 @@ class ParameterExtractorNode(Node):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: Optional[TokenBufferMemory],
+        memory: TokenBufferMemory | None,
         files: Sequence[File],
-        vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
+        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
         """
         Generate chat prompt.
@@ -651,7 +651,7 @@ class ParameterExtractorNode(Node):

         return transformed_result

-    def _extract_complete_json_response(self, result: str) -> Optional[dict]:
+    def _extract_complete_json_response(self, result: str) -> dict | None:
         """
         Extract complete json response.
         """
@@ -666,7 +666,7 @@ class ParameterExtractorNode(Node):
             logger.info("extra error: %s", result)
             return None

-    def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
+    def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
         """
         Extract json from tool call.
         """
@@ -705,7 +705,7 @@ class ParameterExtractorNode(Node):
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        memory: Optional[TokenBufferMemory],
+        memory: TokenBufferMemory | None,
         max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode(node_data.model.mode)
@@ -732,7 +732,7 @@ class ParameterExtractorNode(Node):
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        memory: Optional[TokenBufferMemory],
+        memory: TokenBufferMemory | None,
         max_token_limit: int = 2000,
     ):
         model_mode = ModelMode(node_data.model.mode)
@@ -768,7 +768,7 @@ class ParameterExtractorNode(Node):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
+        context: str | None,
     ) -> int:
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py
index 6248df0edf..edde30708a 100644
--- a/api/core/workflow/nodes/question_classifier/entities.py
+++ b/api/core/workflow/nodes/question_classifier/entities.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import BaseModel, Field

 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
     query_variable_selector: list[str]
     model: ModelConfig
     classes: list[ClassConfig]
-    instruction: Optional[str] = None
-    memory: Optional[MemoryConfig] = None
+    instruction: str | None = None
+    memory: MemoryConfig | None = None
     vision: VisionConfig = Field(default_factory=VisionConfig)

     @property
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 929216652e..945ce113f1 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,6 +1,6 @@
 import json
 from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -80,7 +80,7 @@ class QuestionClassifierNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = QuestionClassifierNodeData.model_validate(data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -89,7 +89,7 @@ class QuestionClassifierNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
@@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
         return variable_mapping

     @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None):
+    def get_default_config(cls, filters: dict | None = None):
         """
         Get default config of node.
         :param filters: filter by node config parameters (not used in this implementation).
@@ -285,7 +285,7 @@ class QuestionClassifierNode(Node):
         node_data: QuestionClassifierNodeData,
         query: str,
         model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
+        context: str | None,
     ) -> int:
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         prompt_template = self._get_prompt_template(node_data, query, None, 2000)
@@ -328,7 +328,7 @@ class QuestionClassifierNode(Node):
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
-        memory: Optional[TokenBufferMemory],
+        memory: TokenBufferMemory | None,
         max_token_limit: int = 2000,
     ):
         model_mode = ModelMode(node_data.model.mode)
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 608f6b11cc..2f33c54128 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,5 +1,5 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any

 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
@@ -18,7 +18,7 @@ class StartNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = StartNodeData(**data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -27,7 +27,7 @@ class StartNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 9039476871..e00c838b1d 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -1,6 +1,6 @@
 import os
 from collections.abc import Mapping, Sequence
-from typing import Any, Optional
+from typing import Any

 from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
 from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
@@ -20,7 +20,7 @@ class TemplateTransformNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = TemplateTransformNodeData.model_validate(data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -29,7 +29,7 @@ class TemplateTransformNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
@@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
         return self._node_data

     @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None):
+    def get_default_config(cls, filters: dict | None = None):
         """
         Get default config of node.
         :param filters: filter by node config parameters.
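The hunks above and below are a mechanical migration from `typing.Optional[T]` to PEP 604's `T | None`. A minimal sketch of the equivalence, assuming Python 3.10+ (the function names here are hypothetical):

```python
from typing import Optional


def old_style(limit: Optional[int] = None) -> Optional[dict]:
    return None if limit is None else {"limit": limit}


def new_style(limit: int | None = None) -> dict | None:
    return None if limit is None else {"limit": limit}


# The two spellings compare equal at runtime on Python 3.10+:
assert (int | None) == Optional[int]

# Caveat: `|` cannot be applied to a quoted forward reference at runtime,
# so `"SomeModel" | None` raises TypeError; such annotations must keep
# Optional[...] or rely on `from __future__ import annotations`.
```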
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index cf41d74d7e..2921b3b911 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,5 +1,5 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -471,7 +471,7 @@ class ToolNode(Node):

         return result

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -480,7 +480,7 @@ class ToolNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py
index f4577d7573..13dbc5dbe6 100644
--- a/api/core/workflow/nodes/variable_aggregator/entities.py
+++ b/api/core/workflow/nodes/variable_aggregator/entities.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import BaseModel

 from core.variables.types import SegmentType
@@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
     type: str = "variable-assigner"
     output_type: str
     variables: list[list[str]]
-    advanced_settings: Optional[AdvancedSettings] = None
+    advanced_settings: AdvancedSettings | None = None
diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
index d2627d9d3b..be00d55937 100644
--- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
+++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
@@ -1,5 +1,5 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any

 from core.variables.segments import Segment
 from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
@@ -17,7 +17,7 @@ class VariableAggregatorNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = VariableAssignerNodeData(**data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -26,7 +26,7 @@ class VariableAggregatorNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py
index 8caee27363..04a7323739 100644
--- a/api/core/workflow/nodes/variable_assigner/common/helpers.py
+++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py
@@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel):
     name: str
     selector: Sequence[str]
     value_type: SegmentType
-    new_value: Any
+    new_value: Any = None

 _T = TypeVar("_T", bound=MutableMapping[str, Any])
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index 5eb9938b9e..c2a9ecd7fb 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, TypeAlias
+from typing import TYPE_CHECKING, Any, TypeAlias

 from core.variables import SegmentType, Variable
 from core.variables.segments import BooleanSegment
@@ -33,7 +33,7 @@ class VariableAssignerNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = VariableAssignerData.model_validate(data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -42,7 +42,7 @@ class VariableAssignerNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index e7833aa46f..a89055fd66 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -1,6 +1,6 @@
 import json
 from collections.abc import Mapping, MutableMapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, cast

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.variables import SegmentType, Variable
@@ -60,7 +60,7 @@ class VariableAssignerNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = VariableAssignerNodeData.model_validate(data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -69,7 +69,7 @@ class VariableAssignerNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py
index e36789152a..43b41ff6b8 100644
--- a/api/core/workflow/repositories/workflow_node_execution_repository.py
+++ b/api/core/workflow/repositories/workflow_node_execution_repository.py
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Literal, Optional, Protocol
+from typing import Literal, Protocol

 from core.workflow.entities import WorkflowNodeExecution

@@ -10,7 +10,7 @@ class OrderConfig:
     """Configuration for ordering NodeExecution instances."""

     order_by: list[str]
-    order_direction: Optional[Literal["asc", "desc"]] = None
+    order_direction: Literal["asc", "desc"] | None = None

 class WorkflowNodeExecutionRepository(Protocol):
@@ -56,7 +56,7 @@ class WorkflowNodeExecutionRepository(Protocol):
     def get_by_workflow_run(
         self,
         workflow_run_id: str,
-        order_config: Optional[OrderConfig] = None,
+        order_config: OrderConfig | None = None,
     ) -> Sequence[WorkflowNodeExecution]:
         """
         Retrieve all NodeExecution instances for a specific workflow run.
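Because `WorkflowNodeExecutionRepository` is a `typing.Protocol`, conforming repositories are matched structurally rather than by inheritance. A hedged sketch of a caller driving `get_by_workflow_run` with the `OrderConfig` dataclass above (the `"index"` sort key is a hypothetical example, not a field confirmed by this diff):

```python
from core.workflow.repositories.workflow_node_execution_repository import (
    OrderConfig,
    WorkflowNodeExecutionRepository,
)


def load_run_steps(repo: WorkflowNodeExecutionRepository, run_id: str):
    # order_direction defaults to None, so callers that only care about the
    # sort keys can construct OrderConfig(order_by=[...]) and omit it.
    config = OrderConfig(order_by=["index"], order_direction="desc")
    return repo.get_by_workflow_run(workflow_run_id=run_id, order_config=config)
```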
diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py
index d6c89d385a..a88f350a9e 100644
--- a/api/core/workflow/workflow_cycle_manager.py
+++ b/api/core/workflow/workflow_cycle_manager.py
@@ -1,7 +1,7 @@
 from collections.abc import Mapping
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any, Union

 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
@@ -85,9 +85,9 @@ class WorkflowCycleManager:
         total_tokens: int,
         total_steps: int,
         outputs: Mapping[str, Any] | None = None,
-        conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
-        external_trace_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        trace_manager: TraceQueueManager | None = None,
+        external_trace_id: str | None = None,
     ) -> WorkflowExecution:
         workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

@@ -112,9 +112,9 @@ class WorkflowCycleManager:
         total_steps: int,
         outputs: Mapping[str, Any] | None = None,
         exceptions_count: int = 0,
-        conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
-        external_trace_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        trace_manager: TraceQueueManager | None = None,
+        external_trace_id: str | None = None,
     ) -> WorkflowExecution:
         execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

@@ -140,10 +140,10 @@ class WorkflowCycleManager:
         total_steps: int,
         status: WorkflowExecutionStatus,
         error_message: str,
-        conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        conversation_id: str | None = None,
+        trace_manager: TraceQueueManager | None = None,
         exceptions_count: int = 0,
-        external_trace_id: Optional[str] = None,
+        external_trace_id: str | None = None,
     ) -> WorkflowExecution:
         workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
         now = naive_utc_now()
@@ -302,9 +302,9 @@ class WorkflowCycleManager:
         total_tokens: int,
         total_steps: int,
         outputs: Mapping[str, Any] | None = None,
-        error_message: Optional[str] = None,
+        error_message: str | None = None,
         exceptions_count: int = 0,
-        finished_at: Optional[datetime] = None,
+        finished_at: datetime | None = None,
     ):
         """Update workflow execution with completion data."""
         execution.status = status
@@ -318,10 +318,10 @@ class WorkflowCycleManager:

     def _add_trace_task_if_needed(
         self,
-        trace_manager: Optional[TraceQueueManager],
+        trace_manager: TraceQueueManager | None,
         workflow_execution: WorkflowExecution,
-        conversation_id: Optional[str],
-        external_trace_id: Optional[str],
+        conversation_id: str | None,
+        external_trace_id: str | None,
     ):
         """Add trace task if trace manager is provided."""
         if trace_manager:
@@ -363,8 +363,8 @@ class WorkflowCycleManager:
         workflow_execution: WorkflowExecution,
         event: QueueNodeStartedEvent,
         status: WorkflowNodeExecutionStatus,
-        error: Optional[str] = None,
-        created_at: Optional[datetime] = None,
+        error: str | None = None,
+        created_at: datetime | None = None,
     ) -> WorkflowNodeExecution:
         """Create a node execution from an event."""
         now = naive_utc_now()
@@ -408,7 +408,7 @@ class WorkflowCycleManager:
             QueueNodeExceptionEvent,
         ],
         status: WorkflowNodeExecutionStatus,
-        error: Optional[str] = None,
+        error: str | None = None,
         handle_special_values: bool = False,
     ):
         """Update node execution with completion data."""
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index d7d539914f..83594bff8b 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -343,7 +343,7 @@ class WorkflowEntry:
             raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))

     @staticmethod
-    def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
+    def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
         # NOTE(QuantumGhost): Avoid using this function in new code.
         # Keep values structured as long as possible and only convert to dict
         # immediately before serialization (e.g., JSON serialization) to maintain
@@ -399,7 +399,7 @@ class WorkflowEntry:
                 raise ValueError(f"Variable key {node_variable} not found in user inputs.")

             # environment variable already exist in variable pool, not from user inputs
-            if variable_pool.get(variable_selector):
+            if variable_pool.get(variable_selector) and variable_selector[0] == ENVIRONMENT_VARIABLE_NODE_ID:
                 continue

             # fetch variable node id from variable selector
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
index b8b5a89dc5..69959acd19 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
@@ -1,3 +1,5 @@
+from sqlalchemy import select
+
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from models.dataset import AppDatasetJoin
@@ -13,7 +15,7 @@ def handle(sender, **kwargs):

     dataset_ids = get_dataset_ids_from_model_config(app_model_config)

-    app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
+    app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

     removed_dataset_ids: set[str] = set()
     if not app_dataset_joins:
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
index fcc3b63fa7..898ec1f153 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
@@ -1,5 +1,7 @@
 from typing import cast

+from sqlalchemy import select
+
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from events.app_event import app_published_workflow_was_updated
@@ -15,7 +17,7 @@ def handle(sender, **kwargs):
     published_workflow = cast(Workflow, published_workflow)

     dataset_ids = get_dataset_ids_from_workflow(published_workflow)
-    app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
+    app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

     removed_dataset_ids: set[str] = set()
     if not app_dataset_joins:
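The two event-handler hunks above replace the legacy `Query` API with SQLAlchemy 2.0-style `select()` executed through `session.scalars()`. A self-contained sketch of the equivalence against an in-memory SQLite database (the model below is a stand-in, not the real `AppDatasetJoin`):

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class AppDatasetJoin(Base):
    __tablename__ = "app_dataset_joins"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(AppDatasetJoin(app_id="app-1"))
    session.commit()

    # Legacy 1.x style: Query returns the entities directly.
    legacy = session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == "app-1").all()

    # 2.0 style used in the diff: build a Select, then take its scalars.
    modern = session.scalars(
        select(AppDatasetJoin).where(AppDatasetJoin.app_id == "app-1")
    ).all()

    assert [row.id for row in legacy] == [row.id for row in modern]
```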
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
index c67d0ca508..27efa539dc 100644
--- a/api/events/event_handlers/update_provider_when_message_created.py
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -1,7 +1,7 @@
 import logging
 import time as time_module
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any

 from pydantic import BaseModel
 from sqlalchemy import update
@@ -33,7 +33,7 @@ def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:

 @redis_fallback(default_return=None)
-def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]:
+def _get_last_update_timestamp(cache_key: str) -> datetime | None:
     """Get last update timestamp from Redis cache."""
     timestamp_str = redis_client.get(cache_key)
     if timestamp_str:
@@ -52,8 +52,8 @@ class _ProviderUpdateFilters(BaseModel):

     tenant_id: str
     provider_name: str
-    provider_type: Optional[str] = None
-    quota_type: Optional[str] = None
+    provider_type: str | None = None
+    quota_type: str | None = None

 class _ProviderUpdateAdditionalFilters(BaseModel):
@@ -65,8 +65,8 @@ class _ProviderUpdateAdditionalFilters(BaseModel):
 class _ProviderUpdateValues(BaseModel):
     """Values to update in Provider records."""

-    last_used: Optional[datetime] = None
-    quota_used: Optional[Any] = None  # Can be Provider.quota_used + int expression
+    last_used: datetime | None = None
+    quota_used: Any | None = None  # Can be Provider.quota_used + int expression

 class _ProviderUpdateOperation(BaseModel):
@@ -182,7 +182,7 @@ def handle(sender: Message, **kwargs):

 def _calculate_quota_usage(
     *, message: Message, system_configuration: SystemConfiguration, model_name: str
-) -> Optional[int]:
+) -> int | None:
     """Calculate quota usage based on message tokens and quota type."""
     quota_unit = None
     for quota_configuration in system_configuration.quota_configurations:
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index fb5352ca8f..585539e2ce 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -1,6 +1,6 @@
 import ssl
 from datetime import timedelta
-from typing import Any, Optional
+from typing import Any

 import pytz
 from celery import Celery, Task
@@ -10,7 +10,7 @@ from configs import dify_config
 from dify_app import DifyApp

-def _get_celery_ssl_options() -> Optional[dict[str, Any]]:
+def _get_celery_ssl_options() -> dict[str, Any] | None:
     """Get SSL configuration for Celery broker/backend connections."""
     # Use REDIS_USE_SSL for consistency with the main Redis client
     # Only apply SSL if we're using Redis as broker/backend
@@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery:
         imports.append("schedule.queue_monitor_task")
         beat_schedule["datasets-queue-monitor"] = {
             "task": "schedule.queue_monitor_task.queue_monitor_task",
-            "schedule": timedelta(
-                minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
-            ),
+            "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
         }
     if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
         imports.append("schedule.check_upgradable_plugin_task")
diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py
index 58ab023559..042bf8cc47 100644
--- a/api/extensions/ext_mail.py
+++ b/api/extensions/ext_mail.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 from flask import Flask

@@ -68,7 +67,7 @@ class Mail:
             case _:
                 raise ValueError(f"Unsupported mail type {mail_type}")

-    def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
+    def send(self, to: str, subject: str, html: str, from_: str | None = None):
         if not self._client:
             raise ValueError("Mail client is not initialized")
diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py
index 61b26b5b95..487917b2a7 100644
--- a/api/extensions/ext_redis.py
+++ b/api/extensions/ext_redis.py
@@ -3,7 +3,7 @@ import logging
 import ssl
 from collections.abc import Callable
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 import redis
 from redis import RedisError
@@ -246,7 +246,7 @@ def init_app(app: DifyApp):
     app.extensions["redis"] = redis_client

-def redis_fallback(default_return: Optional[Any] = None):
+def redis_fallback(default_return: Any | None = None):
     """
     decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py
index 7ec0889776..9053aece89 100644
--- a/api/extensions/storage/azure_blob_storage.py
+++ b/api/extensions/storage/azure_blob_storage.py
@@ -1,6 +1,5 @@
 from collections.abc import Generator
 from datetime import timedelta
-from typing import Optional

 from azure.identity import ChainedTokenCredential, DefaultAzureCredential
 from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
@@ -21,7 +20,7 @@ class AzureBlobStorage(BaseStorage):
         self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
         self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY

-        self.credential: Optional[ChainedTokenCredential] = None
+        self.credential: ChainedTokenCredential | None = None
         if self.account_key == "managedidentity":
             self.credential = DefaultAzureCredential()
         else:
diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
index 33fa7d0a8d..2ffac9a92d 100644
--- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
+++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
@@ -10,7 +10,6 @@ import tempfile
 from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path
-from typing import Optional

 import clickzetta  # type: ignore[import]
 from pydantic import BaseModel, model_validator
@@ -33,7 +32,7 @@ class ClickZettaVolumeConfig(BaseModel):
     vcluster: str = "default_ap"
     schema_name: str = "dify"
     volume_type: str = "table"  # table|user|external
-    volume_name: Optional[str] = None  # For external volumes
+    volume_name: str | None = None  # For external volumes
     table_prefix: str = "dataset_"  # Prefix for table volume names
     dify_prefix: str = "dify_km"  # Directory prefix for User Volume
     permission_check: bool = True  # Enable/disable permission checking
@@ -154,7 +153,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("Failed to initialize permission manager")
             raise

-    def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
+    def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str:
         """Get the appropriate volume path based on volume type."""
         if self._config.volume_type == "user":
             # Add dify prefix for User Volume to organize files
@@ -179,7 +178,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         else:
             raise ValueError(f"Unsupported volume type: {self._config.volume_type}")

-    def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
+    def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str:
         """Get SQL prefix for volume operations."""
         if self._config.volume_type == "user":
             return "USER VOLUME"
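`redis_fallback` is only re-typed in this diff, but its docstring states the contract. A minimal re-implementation consistent with that signature, to illustrate the pattern; the real decorator in `ext_redis.py` may differ in details:

```python
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any

from redis import RedisError

logger = logging.getLogger(__name__)


def redis_fallback(default_return: Any | None = None):
    """Return default_return instead of raising when Redis is unavailable."""

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RedisError:
                logger.warning("Redis unavailable in %s; using fallback", func.__name__)
                return default_return

        return wrapper

    return decorator
```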
diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
index ef6b12fd59..6ab02ad8cc 100644
--- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py
+++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
@@ -7,21 +7,22 @@ Supports complete lifecycle management for knowledge base files.

 import json
 import logging
+import operator
 from dataclasses import asdict, dataclass
 from datetime import datetime
-from enum import Enum
-from typing import Any, Optional
+from enum import StrEnum, auto
+from typing import Any

 logger = logging.getLogger(__name__)

-class FileStatus(Enum):
+class FileStatus(StrEnum):
     """File status enumeration"""

-    ACTIVE = "active"  # Active status
-    ARCHIVED = "archived"  # Archived
-    DELETED = "deleted"  # Deleted (soft delete)
-    BACKUP = "backup"  # Backup file
+    ACTIVE = auto()  # Active status
+    ARCHIVED = auto()  # Archived
+    DELETED = auto()  # Deleted (soft delete)
+    BACKUP = auto()  # Backup file

 @dataclass
@@ -34,9 +35,9 @@ class FileMetadata:
     modified_at: datetime
     version: int | None
     status: FileStatus
-    checksum: Optional[str] = None
-    tags: Optional[dict[str, str]] = None
-    parent_version: Optional[int] = None
+    checksum: str | None = None
+    tags: dict[str, str] | None = None
+    parent_version: int | None = None

     def to_dict(self):
         """Convert to dictionary format"""
@@ -59,7 +60,7 @@ class FileMetadata:
 class FileLifecycleManager:
     """File lifecycle manager"""

-    def __init__(self, storage, dataset_id: Optional[str] = None):
+    def __init__(self, storage, dataset_id: str | None = None):
         """Initialize lifecycle manager

         Args:
@@ -74,9 +75,9 @@ class FileLifecycleManager:
         self._deleted_prefix = ".deleted/"

         # Get permission manager (if exists)
-        self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
+        self._permission_manager: Any | None = getattr(storage, "_permission_manager", None)

-    def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
+    def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata:
         """Save file and manage lifecycle

         Args:
@@ -150,7 +151,7 @@ class FileLifecycleManager:
             logger.exception("Failed to save file with lifecycle")
             raise

-    def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
+    def get_file_metadata(self, filename: str) -> FileMetadata | None:
         """Get file metadata

         Args:
@@ -356,7 +357,7 @@ class FileLifecycleManager:
         # Cleanup old versions for each file
         for base_filename, versions in file_versions.items():
             # Sort by version number
-            versions.sort(key=lambda x: x[0], reverse=True)
+            versions.sort(key=operator.itemgetter(0), reverse=True)

             # Keep the newest max_versions versions, delete the rest
             if len(versions) > max_versions:
diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py
index 243df92efe..eb1116638f 100644
--- a/api/extensions/storage/clickzetta_volume/volume_permissions.py
+++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py
@@ -5,13 +5,12 @@ According to ClickZetta's permission model, different Volume types have differen

 import logging
-from enum import Enum
-from typing import Optional
+from enum import StrEnum

 logger = logging.getLogger(__name__)

-class VolumePermission(Enum):
+class VolumePermission(StrEnum):
     """Volume permission type enumeration"""

     READ = "SELECT"  # Corresponds to ClickZetta's SELECT permission
@@ -24,7 +23,7 @@
 class VolumePermissionManager:
     """Volume permission manager"""

-    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
+    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
         """Initialize permission manager

         Args:
@@ -63,7 +62,7 @@ class VolumePermissionManager:
         self._permission_cache: dict[str, set[str]] = {}
         self._current_username = None  # Will get current username from connection

-    def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
+    def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool:
         """Check if user has permission to perform specific operation

         Args:
@@ -126,7 +125,7 @@ class VolumePermissionManager:
             logger.info("User Volume permission check failed, but permission checking is disabled in this version")
             return False

-    def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
+    def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool:
         """Check Table Volume permission

         Table Volume permission rules:
@@ -440,7 +439,7 @@ class VolumePermissionManager:
         self._permission_cache.clear()
         logger.debug("Permission cache cleared")

-    def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
+    def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
         """Get permission summary

         Args:
@@ -582,7 +581,7 @@ class VolumePermissionManager:

         return any(pattern in file_path_lower for pattern in sensitive_patterns)

-    def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
+    def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool:
         """Validate operation permission

         Args:
@@ -614,16 +613,14 @@ class VolumePermissionManager:
 class VolumePermissionError(Exception):
     """Volume permission error exception"""

-    def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
+    def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None):
         self.operation = operation
         self.volume_type = volume_type
         self.dataset_id = dataset_id
         super().__init__(message)

-def check_volume_permission(
-    permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
-):
+def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
     """Permission check decorator function

     Args:
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 46ed6e15fb..41505ab025 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -511,9 +511,9 @@ class StorageKeyLoader:
                 upload_file_row = upload_files.get(model_id)
                 if upload_file_row is None:
                     raise ValueError(f"Upload file not found for id: {model_id}")
-                file._storage_key = upload_file_row.key
+                file.storage_key = upload_file_row.key
             elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                 tool_file_row = tool_files.get(model_id)
                 if tool_file_row is None:
                     raise ValueError(f"Tool file not found for id: {model_id}")
-                file._storage_key = tool_file_row.file_key
+                file.storage_key = tool_file_row.file_key
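`FileStatus` and `VolumePermission` above move from `Enum` to `StrEnum` (Python 3.11+). Under `StrEnum`, `auto()` yields the lower-cased member name, so the serialized values stay `"active"`, `"archived"`, and so on, while members become usable directly as strings. A small sketch:

```python
from enum import StrEnum, auto


class FileStatus(StrEnum):
    ACTIVE = auto()    # value is "active"
    ARCHIVED = auto()  # value is "archived"


# StrEnum members are real strings, so call sites no longer need .value:
assert FileStatus.ACTIVE == "active"
assert f"status={FileStatus.ARCHIVED}" == "status=archived"
```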
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
index 8288bd54a3..b2b793d40e 100644
--- a/api/fields/_value_type_serializer.py
+++ b/api/fields/_value_type_serializer.py
@@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str:
     if isinstance(v, Segment):
         return v.value_type.exposed_type().value
     else:
-        return v["value_type"].exposed_type().value
+        value_type = v.get("value_type")
+        if value_type is None:
+            raise ValueError("value_type is required but not provided")
+        return value_type.exposed_type().value
diff --git a/api/installed_plugins.jsonl b/api/installed_plugins.jsonl
deleted file mode 100644
index 463e24ae64..0000000000
--- a/api/installed_plugins.jsonl
+++ /dev/null
@@ -1 +0,0 @@
-{"not_installed": [], "plugin_install_failed": []}
\ No newline at end of file
diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py
index 3c039dff53..37ff1a438e 100644
--- a/api/libs/email_i18n.py
+++ b/api/libs/email_i18n.py
@@ -7,8 +7,8 @@ eliminates the need for repetitive language switching logic.
 """

 from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Optional, Protocol
+from enum import StrEnum, auto
+from typing import Any, Protocol

 from flask import render_template
 from pydantic import BaseModel, Field
@@ -17,26 +17,30 @@
 from extensions.ext_mail import mail
 from services.feature_service import BrandingModel, FeatureService

-class EmailType(Enum):
+class EmailType(StrEnum):
     """Enumeration of supported email types."""

-    RESET_PASSWORD = "reset_password"
-    INVITE_MEMBER = "invite_member"
-    EMAIL_CODE_LOGIN = "email_code_login"
-    CHANGE_EMAIL_OLD = "change_email_old"
-    CHANGE_EMAIL_NEW = "change_email_new"
-    CHANGE_EMAIL_COMPLETED = "change_email_completed"
-    OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
-    OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
-    OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
-    ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
-    ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
-    ENTERPRISE_CUSTOM = "enterprise_custom"
-    QUEUE_MONITOR_ALERT = "queue_monitor_alert"
-    DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
+    RESET_PASSWORD = auto()
+    RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
+    INVITE_MEMBER = auto()
+    EMAIL_CODE_LOGIN = auto()
+    CHANGE_EMAIL_OLD = auto()
+    CHANGE_EMAIL_NEW = auto()
+    CHANGE_EMAIL_COMPLETED = auto()
+    OWNER_TRANSFER_CONFIRM = auto()
+    OWNER_TRANSFER_OLD_NOTIFY = auto()
+    OWNER_TRANSFER_NEW_NOTIFY = auto()
+    ACCOUNT_DELETION_SUCCESS = auto()
+    ACCOUNT_DELETION_VERIFICATION = auto()
+    ENTERPRISE_CUSTOM = auto()
+    QUEUE_MONITOR_ALERT = auto()
+    DOCUMENT_CLEAN_NOTIFY = auto()
+    EMAIL_REGISTER = auto()
+    EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
+    RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()

-class EmailLanguage(Enum):
+class EmailLanguage(StrEnum):
     """Supported email languages with fallback handling."""

     EN_US = "en-US"
@@ -167,7 +171,7 @@ class EmailI18nService:
         email_type: EmailType,
         language_code: str,
         to: str,
-        template_context: Optional[dict[str, Any]] = None,
+        template_context: dict[str, Any] | None = None,
     ):
         """
         Send internationalized email with branding support.
@@ -441,6 +445,54 @@ def create_default_email_config() -> EmailI18nConfig:
                 branded_template_path="clean_document_job_mail_template_zh-CN.html",
             ),
         },
+        EmailType.EMAIL_REGISTER: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Register Your {application_title} Account",
+                template_path="register_email_template_en-US.html",
+                branded_template_path="without-brand/register_email_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="注册您的 {application_title} 账户",
+                template_path="register_email_template_zh-CN.html",
+                branded_template_path="without-brand/register_email_template_zh-CN.html",
+            ),
+        },
+        EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Register Your {application_title} Account",
+                template_path="register_email_when_account_exist_template_en-US.html",
+                branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="注册您的 {application_title} 账户",
+                template_path="register_email_when_account_exist_template_zh-CN.html",
+                branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html",
+            ),
+        },
+        EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Reset Your {application_title} Password",
+                template_path="reset_password_mail_when_account_not_exist_template_en-US.html",
+                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="重置您的 {application_title} 密码",
+                template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html",
+                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html",
+            ),
+        },
+        EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Reset Your {application_title} Password",
+                template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
+                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="重置您的 {application_title} 密码",
+                template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
+                branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
+            ),
+        },
     }

     return EmailI18nConfig(templates=templates)
@@ -463,7 +515,7 @@ def get_default_email_i18n_service() -> EmailI18nService:

 # Global instance
-_email_i18n_service: Optional[EmailI18nService] = None
+_email_i18n_service: EmailI18nService | None = None

 def get_email_i18n_service() -> EmailI18nService:
diff --git a/api/libs/exception.py b/api/libs/exception.py
index 5970269ecd..73379dfded 100644
--- a/api/libs/exception.py
+++ b/api/libs/exception.py
@@ -1,11 +1,9 @@
-from typing import Optional
-
 from werkzeug.exceptions import HTTPException

 class BaseHTTPException(HTTPException):
     error_code: str = "unknown"
-    data: Optional[dict] = None
+    data: dict | None = None

     def __init__(self, description=None, response=None):
         super().__init__(description, response)
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index cee80f7f24..cf91b0117f 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api):
             headers["WWW-Authenticate"] = 'Bearer realm="api"'
         return data, status_code, headers

+    _ = handle_http_exception
+
     @api.errorhandler(ValueError)
     def handle_value_error(e: ValueError):
         got_request_exception.send(current_app, exception=e)
@@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api):
         data = {"code": "invalid_param", "message": str(e), "status": status_code}
         return data, status_code

+    _ = handle_value_error
+
     @api.errorhandler(AppInvokeQuotaExceededError)
     def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
         got_request_exception.send(current_app, exception=e)
@@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api):
         data = {"code": "too_many_requests", "message": str(e), "status": status_code}
         return data, status_code

+    _ = handle_quota_exceeded
+
     @api.errorhandler(Exception)
     def handle_general_exception(e: Exception):
         got_request_exception.send(current_app, exception=e)
         status_code = 500

-        data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
+        data = getattr(e, "data", {"message": http_status_message(status_code)})

         # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
-        if not isinstance(data, Mapping):
+        if not isinstance(data, dict):
             data = {"message": str(e)}

         data.setdefault("code", "unknown")
@@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api):
         exc_info: Any = sys.exc_info()
         if exc_info[1] is None:
             exc_info = None
-        current_app.log_exception(exc_info)  # ty: ignore [invalid-argument-type]
+        current_app.log_exception(exc_info)

         return data, status_code

+    _ = handle_general_exception
+

 class ExternalApi(Api):
     _authorizations = {
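`handle_general_exception` above now treats a non-dict `e.data` as opaque and rebuilds the payload. A hedged sketch of that normalization in isolation (`normalize_error_payload` is a hypothetical helper name, not part of the diff):

```python
def normalize_error_payload(e: Exception, status_code: int = 500) -> dict:
    # Prefer an attached `data` dict, as the handler above does; anything
    # else (e.g. a Response object assigned to e.data) is coerced to a
    # plain message so the JSON error body stays well-formed.
    data = getattr(e, "data", {"message": "Internal Server Error"})
    if not isinstance(data, dict):
        data = {"message": str(e)}
    data.setdefault("code", "unknown")
    return data


class Boom(Exception):
    data = {"message": "quota exhausted"}


assert normalize_error_payload(Boom())["code"] == "unknown"
```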
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 139cb329de..0551470f65 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw):
         if isinstance(obj, dict) and "app" in obj:
            obj = obj["app"]

-        if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value:
+        if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
             return file_helpers.get_signed_file_url(obj.icon)
         return None

@@ -167,13 +167,6 @@ class DatetimeString:
         return value

-def _get_float(value):
-    try:
-        return float(value)
-    except (TypeError, ValueError):
-        raise ValueError(f"{value} is not a valid float")
-
-
 def timezone(timezone_string):
     if timezone_string and timezone_string in available_timezones():
         return timezone_string
@@ -276,8 +269,8 @@ class TokenManager:
         cls,
         token_type: str,
         account: Optional["Account"] = None,
-        email: Optional[str] = None,
-        additional_data: Optional[dict] = None,
+        email: str | None = None,
+        additional_data: dict | None = None,
     ) -> str:
         if account is None and email is None:
             raise ValueError("Account or email must be provided")
@@ -319,19 +312,19 @@ class TokenManager:
         redis_client.delete(token_key)

     @classmethod
-    def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
+    def get_token_data(cls, token: str, token_type: str) -> dict[str, Any] | None:
         key = cls._get_token_key(token, token_type)
         token_data_json = redis_client.get(key)

         if token_data_json is None:
             logger.warning("%s token %s not found with key %s", token_type, token, key)
             return None

-        token_data: Optional[dict[str, Any]] = json.loads(token_data_json)
+        token_data: dict[str, Any] | None = json.loads(token_data_json)
         return token_data

     @classmethod
-    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
+    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None:
         key = cls._get_account_token_key(account_id, token_type)
-        current_token: Optional[str] = redis_client.get(key)
+        current_token: str | None = redis_client.get(key)
         return current_token

     @classmethod
diff --git a/api/libs/oauth.py b/api/libs/oauth.py
index df75b55019..35bd6c2c7c 100644
--- a/api/libs/oauth.py
+++ b/api/libs/oauth.py
@@ -1,6 +1,5 @@
 import urllib.parse
 from dataclasses import dataclass
-from typing import Optional

 import requests

@@ -41,7 +40,7 @@ class GitHubOAuth(OAuth):
     _USER_INFO_URL = "https://api.github.com/user"
     _EMAIL_INFO_URL = "https://api.github.com/user/emails"

-    def get_authorization_url(self, invite_token: Optional[str] = None):
+    def get_authorization_url(self, invite_token: str | None = None):
         params = {
             "client_id": self.client_id,
             "redirect_uri": self.redirect_uri,
@@ -93,7 +92,7 @@ class GoogleOAuth(OAuth):
     _TOKEN_URL = "https://oauth2.googleapis.com/token"
     _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"

-    def get_authorization_url(self, invite_token: Optional[str] = None):
+    def get_authorization_url(self, invite_token: str | None = None):
         params = {
             "client_id": self.client_id,
             "response_type": "code",
diff --git a/api/libs/orjson.py b/api/libs/orjson.py
index 2fc5ce8dd3..6e7c6b738d 100644
--- a/api/libs/orjson.py
+++ b/api/libs/orjson.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any

 import orjson

@@ -6,6 +6,6 @@ import orjson
 def orjson_dumps(
     obj: Any,
     encoding: str = "utf-8",
-    option: Optional[int] = None,
+    option: int | None = None,
 ) -> str:
     return orjson.dumps(obj, option=option).decode(encoding)
diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
index 9e02fe1a03..6416f30619 100644
--- a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
+++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
@@ -20,7 +20,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('pipeline_built_in_templates',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
     sa.Column('name', sa.String(length=255), nullable=False),
     sa.Column('description', sa.Text(), nullable=False),
@@ -35,7 +35,7 @@ def upgrade():
     sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
     )
     op.create_table('pipeline_customized_templates',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
     sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
     sa.Column('name', sa.String(length=255), nullable=False),
@@ -52,7 +52,7 @@ def upgrade():
         batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)

     op.create_table('pipelines',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
     sa.Column('name', sa.String(length=255), nullable=False),
     sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False),
diff --git a/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py b/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py
index 0b010d535d..ad7deaaac5 100644
--- a/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py
+++ b/api/migrations/versions/2025_05_30_0033-c459994abfa8_add_pipeline_info_3.py
@@ -20,7 +20,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('datasource_oauth_params',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
     sa.Column('provider', sa.String(length=255), nullable=False),
     sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
@@ -28,7 +28,7 @@ def upgrade():
     sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
     )
     op.create_table('datasource_providers',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
     sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
     sa.Column('provider', sa.String(length=255), nullable=False),
diff --git a/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py
index a695adc74a..071f15adb4 100644
--- a/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py
+++ b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py
@@ -20,7 +20,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('document_pipeline_execution_logs',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
     sa.Column('document_id', models.types.StringUUID(), nullable=False),
     sa.Column('datasource_type', sa.String(length=255), nullable=False),
diff --git a/api/migrations/versions/2025_09_01_1443-8c5088481127_add_pipeline_info_17.py b/api/migrations/versions/2025_09_01_1443-8c5088481127_add_pipeline_info_17.py
index 0269c6a32d..bd9f254cb1 100644
--- a/api/migrations/versions/2025_09_01_1443-8c5088481127_add_pipeline_info_17.py
+++ b/api/migrations/versions/2025_09_01_1443-8c5088481127_add_pipeline_info_17.py
@@ -20,7 +20,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('pipeline_recommended_plugins',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
     sa.Column('plugin_id', sa.Text(), nullable=False),
     sa.Column('provider_name', sa.Text(), nullable=False),
     sa.Column('position', sa.Integer(), nullable=False),
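These migrations swap the `uuid_generate_v4()` server default for `uuidv7()`. UUIDv7 values are time-ordered, which keeps B-tree primary-key inserts close to append-only; note that `uuidv7()` must exist in the target database, either natively (PostgreSQL 18+) or through an extension that provides it. A hedged Alembic sketch with a hypothetical table name:

```python
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql


def upgrade():
    # 'example_rows' is a hypothetical table; the pattern mirrors the
    # migrations above. uuidv7() must be available in the target PostgreSQL,
    # either natively (18+) or via an extension such as pg_uuidv7.
    op.create_table(
        "example_rows",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            server_default=sa.text("uuidv7()"),
            primary_key=True,
        ),
    )


def downgrade():
    op.drop_table("example_rows")
```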
diff --git a/api/migrations/versions/2025_09_11_1537-cf7c38a32b2d_add_credential_status_for_provider_table.py b/api/migrations/versions/2025_09_11_1537-cf7c38a32b2d_add_credential_status_for_provider_table.py
new file mode 100644
index 0000000000..17467e6495
--- /dev/null
+++ b/api/migrations/versions/2025_09_11_1537-cf7c38a32b2d_add_credential_status_for_provider_table.py
@@ -0,0 +1,33 @@
+"""Add credential status for provider table
+
+Revision ID: cf7c38a32b2d
+Revises: c20211f18133
+Create Date: 2025-09-11 15:37:17.771298
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'cf7c38a32b2d'
+down_revision = 'c20211f18133'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('providers', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('credential_status', sa.String(length=20), server_default=sa.text("'active'::character varying"), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('providers', schema=None) as batch_op:
+        batch_op.drop_column('credential_status')
+
+    # ### end Alembic commands ###
\ No newline at end of file
diff --git a/api/models/account.py b/api/models/account.py
index 019159d2da..8c1f990aa2 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -7,6 +7,7 @@ import sqlalchemy as sa
 from flask_login import UserMixin  # type: ignore[import-untyped]
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
+from typing_extensions import deprecated

 from models.base import Base

@@ -89,24 +90,24 @@ class Account(UserMixin, Base):
     id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
     name: Mapped[str] = mapped_column(String(255))
     email: Mapped[str] = mapped_column(String(255))
-    password: Mapped[Optional[str]] = mapped_column(String(255))
-    password_salt: Mapped[Optional[str]] = mapped_column(String(255))
-    avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
-    interface_language: Mapped[Optional[str]] = mapped_column(String(255))
-    interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
-    timezone: Mapped[Optional[str]] = mapped_column(String(255))
-    last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
-    last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+    password: Mapped[str | None] = mapped_column(String(255))
+    password_salt: Mapped[str | None] = mapped_column(String(255))
+    avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    interface_language: Mapped[str | None] = mapped_column(String(255))
+    interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    timezone: Mapped[str | None] = mapped_column(String(255))
+    last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
     last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
     status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
-    initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
     updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)

     @reconstructor
     def init_on_load(self):
-        self.role: Optional[TenantAccountRole] = None
-        self._current_tenant: Optional[Tenant] = None
+        self.role: TenantAccountRole | None = None
+        self._current_tenant: Tenant | None = None

     @property
     def is_password_set(self):
@@ -187,7 +188,28 @@ class Account(UserMixin, Base):
         return TenantAccountRole.is_admin_role(self.role)

     @property
+    @deprecated("Use has_edit_permission instead.")
     def is_editor(self):
+        """Determines if the account has edit permissions in their current tenant (workspace).
+
+        This property checks if the current role has editing privileges, which includes:
+        - `OWNER`
+        - `ADMIN`
+        - `EDITOR`
+
+        Note: This checks for any role with editing permission, not just the 'EDITOR' role specifically.
+        """
+        return self.has_edit_permission
+
+    @property
+    def has_edit_permission(self):
+        """Determines if the account has editing permissions in their current tenant (workspace).
+
+        This property checks if the current role has editing privileges, which includes:
+        - `OWNER`
+        - `ADMIN`
+        - `EDITOR`
+        """
         return TenantAccountRole.is_editing_role(self.role)

     @property
@@ -210,18 +232,20 @@ class Tenant(Base):

     id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
     name: Mapped[str] = mapped_column(String(255))
-    encrypt_public_key: Mapped[Optional[str]] = mapped_column(sa.Text)
+    encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
     plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
     status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
-    custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
+    custom_config: Mapped[str | None] = mapped_column(sa.Text)
     created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
     updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())

     def get_accounts(self) -> list[Account]:
-        return (
-            db.session.query(Account)
-            .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
-            .all()
+        return list(
+            db.session.scalars(
+                select(Account).where(
+                    Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id
+                )
+            ).all()
         )

     @property
@@ -247,7 +271,7 @@ class TenantAccountJoin(Base):
     account_id: Mapped[str] = mapped_column(StringUUID)
     current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
     role: Mapped[str] = mapped_column(String(16), server_default="normal")
-    invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
+    invited_by: Mapped[str | None] = mapped_column(StringUUID)
     created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
     updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())

@@ -281,10 +305,10 @@ class InvitationCode(Base):
     batch: Mapped[str] = mapped_column(String(255))
     code: Mapped[str] = mapped_column(String(32))
     status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
-    used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
-    used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
-    used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
-    deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    used_at: Mapped[datetime | None] = mapped_column(DateTime)
+    used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
+    used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
+    deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
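`is_editor` stays callable but is now marked with `typing_extensions.deprecated`, so type checkers (e.g. basedpyright) flag remaining call sites while runtime behavior is preserved. A small sketch of the mechanism with a stubbed role check:

```python
from typing_extensions import deprecated


class Account:
    @property
    def has_edit_permission(self) -> bool:
        return True  # stand-in for the real role check

    @property
    @deprecated("Use has_edit_permission instead.")
    def is_editor(self) -> bool:
        return self.has_edit_permission


# Still works at runtime (possibly with a DeprecationWarning), but static
# checkers report the deprecation at every use of .is_editor:
assert Account().is_editor
```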
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 0cd53138cc..248c436dfa 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -10,7 +10,7 @@ import re
 import time
 from datetime import datetime
 from json import JSONDecodeError
-from typing import Any, Optional, cast
+from typing import Any, cast

 import sqlalchemy as sa
 from sqlalchemy import DateTime, String, func, select
@@ -56,7 +56,7 @@ class Dataset(Base):
     provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying"))
     permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying"))
     data_source_type = mapped_column(String(255))
-    indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
+    indexing_technique: Mapped[str | None] = mapped_column(String(255))
     index_struct = mapped_column(sa.Text, nullable=True)
     created_by = mapped_column(StringUUID, nullable=False)
     created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -241,7 +241,9 @@ class Dataset(Base):

     @property
     def doc_metadata(self):
-        dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
+        dataset_metadatas = db.session.scalars(
+            select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id)
+        ).all()

         doc_metadata = [
             {
@@ -255,35 +257,35 @@ class Dataset(Base):
         doc_metadata.append(
             {
                 "id": "built-in",
-                "name": BuiltInField.document_name.value,
+                "name": BuiltInField.document_name,
                 "type": "string",
             }
         )
         doc_metadata.append(
             {
                 "id": "built-in",
-                "name": BuiltInField.uploader.value,
+                "name": BuiltInField.uploader,
                 "type": "string",
             }
         )
         doc_metadata.append(
             {
                 "id": "built-in",
-                "name": BuiltInField.upload_date.value,
+                "name": BuiltInField.upload_date,
                 "type": "time",
             }
         )
         doc_metadata.append(
             {
                 "id": "built-in",
-                "name": BuiltInField.last_update_date.value,
+                "name": BuiltInField.last_update_date,
                 "type": "time",
             }
         )
         doc_metadata.append(
             {
                 "id": "built-in",
-                "name": BuiltInField.source.value,
+                "name": BuiltInField.source,
                 "type": "string",
             }
         )
@@ -361,42 +363,42 @@ class Document(Base):
     created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

     # start processing
-    processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    processing_started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # parsing
     file_id = mapped_column(sa.Text, nullable=True)
-    word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)  # TODO: make this not nullable
-    parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)  # TODO: make this not nullable
+    parsing_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # cleaning
-    cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    cleaning_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # split
-    splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    splitting_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # indexing
-    tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
-    indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
-    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+    indexing_latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # pause
-    is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
+    is_paused: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
     paused_by = mapped_column(StringUUID, nullable=True)
-    paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    paused_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # error
     error = mapped_column(sa.Text, nullable=True)
-    stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     # basic fields
     indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying"))
     enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
-    disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     disabled_by = mapped_column(StringUUID, nullable=True)
     archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
     archived_reason = mapped_column(String(255), nullable=True)
     archived_by = mapped_column(StringUUID, nullable=True)
-    archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
     doc_type = mapped_column(String(40), nullable=True)
     doc_metadata = mapped_column(JSONB, nullable=True)
@@ -575,7 +577,7 @@ class Document(Base):
                 "id": "built-in",
                 "name": BuiltInField.source,
                 "type": "string",
-                "value": MetadataDataSource[self.data_source_type].value,
+                "value": MetadataDataSource[self.data_source_type],
             }
         )
         return built_in_fields
@@ -708,17 +710,17 @@ class DocumentSegment(Base):
     # basic fields
     hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
     enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
-    disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     disabled_by = mapped_column(StringUUID, nullable=True)
     status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying"))
     created_by = mapped_column(StringUUID, nullable=False)
     created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = mapped_column(StringUUID, nullable=True)
     updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
-    indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
-    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     error = mapped_column(sa.Text, nullable=True)
-    stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

     @property
     def dataset(self):
@@ -881,8 +883,8 @@ class ChildChunk(Base):
     updated_at: Mapped[datetime] = mapped_column(
         DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
     )
-    indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
-    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     error = mapped_column(sa.Text, nullable=True)

     @property
@@ -1109,13 +1111,11 @@ class ExternalKnowledgeApis(Base):

     @property
     def dataset_bindings(self) -> list[dict[str, Any]]:
-        external_knowledge_bindings = (
-            db.session.query(ExternalKnowledgeBindings)
-            .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
-            .all()
-        )
+        external_knowledge_bindings = db.session.scalars(
+            select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
+        ).all()
         dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
-        datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
+        datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
         dataset_bindings: list[dict[str, Any]] = []
         for dataset in datasets:
             dataset_bindings.append({"id": dataset.id, "name": dataset.name})
@@ -1226,7 +1226,7 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     __tablename__ = "pipeline_built_in_templates"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)

-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False)
     chunk_structure = db.Column(db.String(255), nullable=False)
@@ -1257,7 +1257,7 @@ class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
         db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
     )

-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
     tenant_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False)
@@ -1284,7 +1284,7 @@ class Pipeline(Base):  # type: ignore[name-defined]
     __tablename__ = "pipelines"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)

-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
     tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
@@ -1307,7 +1307,7 @@ class DocumentPipelineExecutionLog(Base):
         db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
     )

-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
     pipeline_id = db.Column(StringUUID, nullable=False)
     document_id = db.Column(StringUUID, nullable=False)
     datasource_type = db.Column(db.String(255), nullable=False)
@@ -1322,7 +1322,7 @@ class PipelineRecommendedPlugin(Base):
     __tablename__ = "pipeline_recommended_plugins"
     __table_args__ = (db.PrimaryKeyConstraint("id",
name="pipeline_recommended_plugin_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, server_default=db.text("uuidv7()")) plugin_id = db.Column(db.Text, nullable=False) provider_name = db.Column(db.Text, nullable=False) position = db.Column(db.Integer, nullable=False, default=0) diff --git a/api/models/model.py b/api/models/model.py index feeaaa0da5..783bef6f88 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ import re import uuid from collections.abc import Mapping from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa @@ -62,9 +62,9 @@ class AppMode(StrEnum): raise ValueError(f"invalid mode value {value}") -class IconType(Enum): - IMAGE = "image" - EMOJI = "emoji" +class IconType(StrEnum): + IMAGE = auto() + EMOJI = auto() class App(Base): @@ -76,9 +76,9 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji icon = mapped_column(String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(String(255)) + icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) @@ -90,7 +90,7 @@ class App(Base): is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) tracing = mapped_column(sa.Text, nullable=True) - max_active_requests: Mapped[Optional[int]] + max_active_requests: Mapped[int | None] created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -134,7 +134,7 @@ class App(Base): return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -149,15 +149,15 @@ class App(Base): if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get( "strategy", "" ) in {"function_call", "react"}: - self.mode = AppMode.AGENT_CHAT.value + self.mode = AppMode.AGENT_CHAT db.session.commit() return True return False @property def mode_compatible_with_agent(self) -> str: - if self.mode == AppMode.CHAT.value and self.is_agent: - return AppMode.AGENT_CHAT.value + if self.mode == AppMode.CHAT and self.is_agent: + return AppMode.AGENT_CHAT return str(self.mode) @@ -292,7 +292,7 @@ class App(Base): return tags or [] @property - def author_name(self) -> Optional[str]: + def author_name(self) -> str | None: if self.created_by: account = db.session.query(Account).where(Account.id == self.created_by).first() if account: @@ -335,7 +335,7 @@ class AppModelConfig(Base): file_upload = mapped_column(sa.Text) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = 
db.session.query(App).where(App.id == self.app_id).first() return app @@ -547,7 +547,7 @@ class RecommendedApp(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -571,12 +571,12 @@ class InstalledApp(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -712,9 +712,9 @@ class Conversation(Base): @property def model_config(self): model_config = {} - app_model_config: Optional[AppModelConfig] = None + app_model_config: AppModelConfig | None = None - if self.mode == AppMode.ADVANCED_CHAT.value: + if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) model_config = override_model_configs @@ -813,7 +813,7 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.query(Message).where(Message.conversation_id == self.id).all() + messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -860,7 +860,7 @@ class Conversation(Base): return None @property - def from_account_name(self) -> Optional[str]: + def from_account_name(self) -> str | None: if self.from_account_id: account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: @@ -935,14 +935,14 @@ class Message(Base): status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) error = mapped_column(sa.Text) message_metadata = mapped_column(sa.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) from_source: Mapped[str] = mapped_column(String(255), nullable=False) - from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) - from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) + from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) + from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) @property def inputs(self) -> dict[str, Any]: @@ -1092,7 +1092,7 @@ class Message(Base): @property def feedbacks(self): - feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() + feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all() return feedbacks @property @@ -1147,7 +1147,7 @@ class Message(Base): def message_files(self) -> list[dict[str, Any]]: from factories import file_factory - 
        message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
+        message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
         current_app = db.session.query(App).where(App.id == self.app_id).first()
         if not current_app:
             raise ValueError(f"App {self.app_id} not found")
@@ -1339,9 +1339,9 @@ class MessageFile(Base):
     message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     type: Mapped[str] = mapped_column(String(255), nullable=False)
     transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
-    url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
-    belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
-    upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
+    url: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
+    belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
     created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
     created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -1358,8 +1358,8 @@ class MessageAnnotation(Base):
     id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
     app_id: Mapped[str] = mapped_column(StringUUID)
-    conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
-    message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+    conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
+    message_id: Mapped[str | None] = mapped_column(StringUUID)
     question = mapped_column(sa.Text, nullable=True)
     content = mapped_column(sa.Text, nullable=False)
     hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@@ -1461,6 +1461,14 @@ class OperationLog(Base):
     updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
+class DefaultEndUserSessionID(StrEnum):
+    """
+    End User Session ID enum.
+    """
+
+    DEFAULT_SESSION_ID = "DEFAULT-USER"
+
+
 class EndUser(Base, UserMixin):
     __tablename__ = "end_users"
     __table_args__ = (
@@ -1747,18 +1755,18 @@ class MessageAgentThought(Base):
     # plugin_id = mapped_column(StringUUID, nullable=True)  ## for future design
     tool_process_data = mapped_column(sa.Text, nullable=True)
     message = mapped_column(sa.Text, nullable=True)
-    message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     message_unit_price = mapped_column(sa.Numeric, nullable=True)
     message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
     message_files = mapped_column(sa.Text, nullable=True)
     answer = mapped_column(sa.Text, nullable=True)
-    answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     answer_unit_price = mapped_column(sa.Numeric, nullable=True)
     answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
-    tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     total_price = mapped_column(sa.Numeric, nullable=True)
     currency = mapped_column(String, nullable=True)
-    latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
+    latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
     created_by_role = mapped_column(String, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@@ -1856,11 +1864,11 @@ class DatasetRetrieverResource(Base):
     document_name = mapped_column(sa.Text, nullable=False)
     data_source_type = mapped_column(sa.Text, nullable=True)
     segment_id = mapped_column(StringUUID, nullable=True)
-    score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
+    score: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
     content = mapped_column(sa.Text, nullable=False)
-    hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
-    word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
-    segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
+    hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+    word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+    segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
     index_node_hash = mapped_column(sa.Text, nullable=True)
     retriever_from = mapped_column(sa.Text, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
diff --git a/api/models/oauth.py b/api/models/oauth.py
index 9869fb40ff..b6a76793fc 100644
--- a/api/models/oauth.py
+++ b/api/models/oauth.py
@@ -15,7 +15,7 @@ class DatasourceOauthParamConfig(Base):  # type: ignore[name-defined]
         db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
     provider: Mapped[str] = db.Column(db.String(255), nullable=False)
     system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
@@ -28,7 +28,7 @@ class DatasourceProvider(Base):
         db.UniqueConstraint("tenant_id",
"plugin_id", "provider", "name", name="datasource_provider_unique_name"), db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, server_default=db.text("uuidv7()")) tenant_id = db.Column(StringUUID, nullable=False) name: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -50,7 +50,7 @@ class DatasourceOauthTenantParamConfig(Base): db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, server_default=db.text("uuidv7()")) tenant_id = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) diff --git a/api/models/provider.py b/api/models/provider.py index 9a344ea56d..aacc6e505a 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,6 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from functools import cached_property -from typing import Optional import sqlalchemy as sa from sqlalchemy import DateTime, String, func, text @@ -12,9 +11,9 @@ from .engine import db from .types import StringUUID -class ProviderType(Enum): - CUSTOM = "custom" - SYSTEM = "system" +class ProviderType(StrEnum): + CUSTOM = auto() + SYSTEM = auto() @staticmethod def value_of(value: str) -> "ProviderType": @@ -24,14 +23,14 @@ class ProviderType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod @@ -63,14 +62,14 @@ class Provider(Base): String(40), nullable=False, server_default=text("'custom'::character varying") ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - quota_type: Mapped[Optional[str]] = mapped_column( + quota_type: Mapped[str | None] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") ) - quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True) - quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0) + quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -133,7 +132,7 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - credential_id: 
Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -201,17 +200,17 @@ class ProviderOrder(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) + payment_id: Mapped[str | None] = mapped_column(String(191)) + transaction_id: Mapped[str | None] = mapped_column(String(191)) quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(String(40)) - total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer) + currency: Mapped[str | None] = mapped_column(String(40)) + total_amount: Mapped[int | None] = mapped_column(sa.Integer) payment_status: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + paid_at: Mapped[datetime | None] = mapped_column(DateTime) + pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) + refunded_at: Mapped[datetime | None] = mapped_column(DateTime) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -255,9 +254,9 @@ class LoadBalancingModelConfig(Base): model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) + encrypted_config: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 8456d65a87..5b4c486bc4 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,6 +1,5 @@ import json from datetime import datetime -from typing import Optional import sqlalchemy as sa from sqlalchemy import DateTime, String, func @@ -27,7 +26,7 @@ class DataSourceOauthBinding(Base): source_info = mapped_column(JSONB, 
nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -45,7 +44,7 @@ class DataSourceApiKeyAuthBinding(Base): credentials = mapped_column(sa.Text, nullable=True) # JSON created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 9a52fcfb41..3da1674536 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional import sqlalchemy as sa from celery import states @@ -32,7 +31,7 @@ class CeleryTask(Base): args = mapped_column(sa.LargeBinary, nullable=True) kwargs = mapped_column(sa.LargeBinary, nullable=True) worker = mapped_column(String(155), nullable=True) - retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) queue = mapped_column(String(155), nullable=True) @@ -46,4 +45,4 @@ class CeleryTaskSet(Base): ) taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 277a9d032c..dae3d6eb88 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse import sqlalchemy as sa @@ -501,13 +501,13 @@ class ToolFile(TypeBase): # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID) # conversation id - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) # file key file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None) + original_url: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None) # name name: Mapped[str] = mapped_column(default="") # size diff --git a/api/models/workflow.py b/api/models/workflow.py index f654679956..a25d65669a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,8 +2,8 @@ import json import logging from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, 
Optional, Union, cast +from enum import StrEnum +from typing import TYPE_CHECKING, Any, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -43,7 +43,7 @@ from .types import EnumText, StringUUID logger = logging.getLogger(__name__) -class WorkflowType(Enum): +class WorkflowType(StrEnum): """ Workflow Type Enum """ @@ -133,7 +133,7 @@ class Workflow(Base): _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) + updated_by: Mapped[str | None] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, @@ -531,18 +531,18 @@ class WorkflowRun(Base): type: Mapped[str] = mapped_column(String(255)) triggered_from: Mapped[str] = mapped_column(String(255)) version: Mapped[str] = mapped_column(String(255)) - graph: Mapped[Optional[str]] = mapped_column(sa.Text) - inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + graph: Mapped[str | None] = mapped_column(sa.Text) + inputs: Mapped[str | None] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded - outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[Optional[str]] = mapped_column(sa.Text) + outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}") + error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) @property @@ -739,24 +739,24 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) triggered_from: Mapped[str] = mapped_column(String(255)) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) + node_execution_id: Mapped[str | None] = mapped_column(String(255)) node_id: Mapped[str] = mapped_column(String(255)) node_type: Mapped[str] = mapped_column(String(255)) title: Mapped[str] = mapped_column(String(255)) - inputs: Mapped[Optional[str]] = mapped_column(sa.Text) - process_data: Mapped[Optional[str]] = mapped_column(sa.Text) - outputs: Mapped[Optional[str]] = mapped_column(sa.Text) + inputs: Mapped[str | None] = mapped_column(sa.Text) + process_data: Mapped[str | None] = mapped_column(sa.Text) + outputs: 
Mapped[str | None] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) - error: Mapped[Optional[str]] = mapped_column(sa.Text) + error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) - execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text) + execution_metadata: Mapped[str | None] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( "WorkflowNodeExecutionOffload", @@ -953,7 +953,7 @@ class WorkflowNodeExecutionOffload(Base): ) -class WorkflowAppLogCreatedFrom(Enum): +class WorkflowAppLogCreatedFrom(StrEnum): """ Workflow App Log Created From Enum """ diff --git a/api/pyproject.toml b/api/pyproject.toml index 24b2b6ebe4..5db0d045fe 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -79,7 +79,7 @@ dependencies = [ "sqlalchemy~=2.0.29", "starlette==0.47.2", "tiktoken~=0.9.0", - "transformers~=4.53.0", + "transformers~=4.56.1", "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", "weave~=0.51.0", "yarl~=1.18.3", @@ -169,6 +169,8 @@ dev = [ "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", "mypy~=1.17.1", + "locust>=2.40.4", + "sseclient-py>=1.8.0", ] ############################################################ diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a3a5f2044e..7c59c2ca28 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,24 +1,44 @@ { - "include": ["models", "configs"], - "exclude": [".venv", "tests/", "migrations/"], - "ignore": [ - "core/", - "controllers/", - "tasks/", - "services/", - "schedule/", - "extensions/", - "utils/", - "repositories/", - "libs/", - "fields/", - "factories/", - "events/", - "contexts/", - "constants/", - "commands.py" + "include": ["."], + "exclude": [ + ".venv", + "tests/", + "migrations/", + "core/rag", + "extensions", + "libs", + "controllers/console/datasets", + "controllers/service_api/dataset", + "core/ops", + "core/tools", + "core/model_runtime", + "core/workflow", + "core/app/app_config/easy_ui_based_app/dataset" ], "typeCheckingMode": "strict", + "allowedUntypedLibraries": [ + "flask_restx", + "flask_login", + "opentelemetry.instrumentation.celery", + "opentelemetry.instrumentation.flask", + "opentelemetry.instrumentation.requests", + "opentelemetry.instrumentation.sqlalchemy", + "opentelemetry.instrumentation.redis" + ], + "reportUnknownMemberType": "hint", + "reportUnknownParameterType": "hint", + "reportUnknownArgumentType": "hint", + "reportUnknownVariableType": "hint", + "reportUnknownLambdaType": "hint", + "reportMissingParameterType": "hint", + "reportMissingTypeArgument": "hint", + "reportUnnecessaryContains": "hint", + "reportUnnecessaryComparison": "hint", + "reportUnnecessaryCast": "hint", + "reportUnnecessaryIsInstance": "hint", + "reportUntypedFunctionDecorator": "hint", + + "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" } diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 00a2d1f87d..fa2c94b623 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ 
b/api/repositories/api_workflow_node_execution_repository.py @@ -11,7 +11,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m from collections.abc import Sequence from datetime import datetime -from typing import Optional, Protocol +from typing import Protocol from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel @@ -44,7 +44,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr app_id: str, workflow_id: str, node_id: str, - ) -> Optional[WorkflowNodeExecutionModel]: + ) -> WorkflowNodeExecutionModel | None: """ Get the most recent execution for a specific node. @@ -87,8 +87,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr def get_execution_by_id( self, execution_id: str, - tenant_id: Optional[str] = None, - ) -> Optional[WorkflowNodeExecutionModel]: + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: """ Get a workflow node execution by its ID. diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 59e7baeb79..3ac28fad75 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -36,7 +36,7 @@ Example: from collections.abc import Sequence from datetime import datetime -from typing import Optional, Protocol +from typing import Protocol from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -58,7 +58,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): app_id: str, triggered_from: str, limit: int = 20, - last_id: Optional[str] = None, + last_id: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -90,7 +90,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): tenant_id: str, app_id: str, run_id: str, - ) -> Optional[WorkflowRun]: + ) -> WorkflowRun | None: """ Get a specific workflow run by ID. diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 5ed278b15d..9bc6acc41f 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -7,7 +7,6 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. from collections.abc import Sequence from datetime import datetime -from typing import Optional from sqlalchemy import asc, delete, desc, select from sqlalchemy.orm import Session, sessionmaker @@ -49,7 +48,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut app_id: str, workflow_id: str, node_id: str, - ) -> Optional[WorkflowNodeExecutionModel]: + ) -> WorkflowNodeExecutionModel | None: """ Get the most recent execution for a specific node. @@ -116,8 +115,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut def get_execution_by_id( self, execution_id: str, - tenant_id: Optional[str] = None, - ) -> Optional[WorkflowNodeExecutionModel]: + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: """ Get a workflow node execution by its ID. 
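The hunks throughout this diff repeatedly swap legacy `db.session.query(...).where(...).all()` chains for SQLAlchemy 2.0-style `select()` statements executed via `session.scalars()`. A minimal, self-contained sketch of the before/after shapes; the `Item` model, in-memory SQLite engine, and local `session` are illustrative stand-ins, not code from this PR:

```python
from sqlalchemy import Boolean, Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):  # hypothetical stand-in for the PR's mapped models
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Item())
    session.commit()

    # Legacy 1.x pattern the PR removes: a Query built from the session.
    legacy = session.query(Item).where(Item.enabled == True).all()

    # 2.0 pattern the PR adopts: build a Select, execute it, and unwrap
    # single-entity rows with .scalars().
    modern = session.scalars(select(Item).where(Item.enabled == True)).all()

    # Both fetch the same rows; the identity map returns the same objects.
    assert list(modern) == legacy
```

`scalars().all()` is typed as returning a `Sequence`, which is presumably why `Tenant.get_accounts()` wraps its result in `list()` to satisfy the declared `list[Account]` return type, and why `update_clusters()` loosens its parameter annotation to `Sequence[TidbAuthBinding]`.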
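Several enums in this diff (`IconType`, `ProviderType`, `ProviderQuotaType`, `WorkflowType`, `WorkflowAppLogCreatedFrom`) move from plain `Enum` to `StrEnum`, which is what lets the accompanying hunks drop `.value` in comparisons and dict literals (`self.mode == AppMode.CHAT`, `"name": BuiltInField.document_name`). A short sketch of the behavioral difference, using hypothetical old/new shapes of `IconType` (Python 3.11+, matching the `pythonVersion` pinned in pyrightconfig.json):

```python
from enum import Enum, StrEnum, auto


class IconTypeOld(Enum):  # pre-PR shape
    IMAGE = "image"
    EMOJI = "emoji"


class IconTypeNew(StrEnum):  # post-PR shape; auto() yields the lowercased name
    IMAGE = auto()
    EMOJI = auto()


# A plain Enum member never compares equal to its underlying value...
assert IconTypeOld.IMAGE != "image"
assert IconTypeOld.IMAGE.value == "image"

# ...but a StrEnum member IS a str, so equality checks, dict values, and
# string formatting all work without the .value suffix.
assert IconTypeNew.IMAGE == "image"
assert f"{IconTypeNew.EMOJI}" == "emoji"
```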
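The account_service.py hunks further below add a Redis-backed, fixed-window error counter for e-mail registration (`add_email_register_error_rate_limit` / `is_email_register_error_rate_limit`). A rough sketch of that pattern with plain redis-py; the key prefix mirrors the diff, but the client plus `MAX_ERRORS` and `LOCKOUT_SECONDS` stand in for the project's `redis_client`, `EMAIL_REGISTER_MAX_ERROR_LIMITS`, and `dify_config.EMAIL_REGISTER_LOCKOUT_DURATION`, whose values are not shown here:

```python
import redis

r = redis.Redis()        # stand-in for the project's redis_client
MAX_ERRORS = 5           # mirrors EMAIL_REGISTER_MAX_ERROR_LIMITS
LOCKOUT_SECONDS = 86400  # assumed value for EMAIL_REGISTER_LOCKOUT_DURATION


def record_error(email: str) -> None:
    key = f"email_register_error_rate_limit:{email}"
    count = int(r.get(key) or 0) + 1
    # setex rewrites the TTL on every failure, so the lockout window
    # slides forward as long as errors keep arriving.
    r.setex(key, LOCKOUT_SECONDS, count)


def is_locked_out(email: str) -> bool:
    key = f"email_register_error_rate_limit:{email}"
    count = r.get(key)
    return count is not None and int(count) > MAX_ERRORS
```

The same counter shape backs the forgot-password and change-email limits in the service, each with its own key prefix and lockout duration.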
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 6294846f5e..205f8c87ee 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -22,7 +22,6 @@ Implementation Notes: import logging from collections.abc import Sequence from datetime import datetime -from typing import Optional from sqlalchemy import delete, select from sqlalchemy.orm import Session, sessionmaker @@ -61,7 +60,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): app_id: str, triggered_from: str, limit: int = 20, - last_id: Optional[str] = None, + last_id: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -107,7 +106,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): tenant_id: str, app_id: str, run_id: str, - ) -> Optional[WorkflowRun]: + ) -> WorkflowRun | None: """ Get a specific workflow run by ID with tenant and app isolation. """ diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 63e6132b6a..9efd46ba5d 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -1,6 +1,6 @@ import datetime import time -from typing import Optional, TypedDict +from typing import TypedDict import click from sqlalchemy import func, select @@ -17,7 +17,7 @@ from services.feature_service import FeatureService class CleanupConfig(TypedDict): clean_day: datetime.datetime - plan_filter: Optional[str] + plan_filter: str | None add_logs: bool @@ -96,11 +96,11 @@ def clean_unused_datasets_task(): break for dataset in datasets: - dataset_query = ( - db.session.query(DatasetQuery) - .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id) - .all() - ) + dataset_query = db.session.scalars( + select(DatasetQuery).where( + DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id + ) + ).all() if not dataset_query or len(dataset_query) == 0: try: @@ -121,15 +121,13 @@ def clean_unused_datasets_task(): if should_clean: # Add auto disable log if required if add_logs: - documents = ( - db.session.query(Document) - .where( + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset.id, Document.enabled == True, Document.archived == False, ) - .all() - ) + ).all() for document in documents: dataset_auto_disable_log = DatasetAutoDisableLog( tenant_id=dataset.tenant_id, diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 9e32ecc716..ef6edd6709 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -3,6 +3,7 @@ import time from collections import defaultdict import click +from sqlalchemy import select import app from configs import dify_config @@ -31,9 +32,9 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() - ) + dataset_auto_disable_logs = db.session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) + ).all() # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: diff --git 
a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1bfeb869e2..1befa0e8b5 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -1,6 +1,8 @@ import time +from collections.abc import Sequence import click +from sqlalchemy import select import app from configs import dify_config @@ -15,11 +17,9 @@ def update_tidb_serverless_status_task(): start_at = time.perf_counter() try: # check the number of idle tidb serverless - tidb_serverless_list = ( - db.session.query(TidbAuthBinding) - .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") - .all() - ) + tidb_serverless_list = db.session.scalars( + select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + ).all() if len(tidb_serverless_list) == 0: return # update tidb serverless status @@ -32,7 +32,7 @@ def update_tidb_serverless_status_task(): click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) -def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): +def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]): try: # batch 20 for i in range(0, len(tidb_serverless_list), 20): diff --git a/api/services/account_service.py b/api/services/account_service.py index a76792f88e..8362e415c1 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -5,7 +5,7 @@ import secrets import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel from sqlalchemy import func @@ -37,7 +37,6 @@ from services.billing_service import BillingService from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountNotLinkTenantError, AccountPasswordError, AccountRegisterError, @@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import ( send_old_owner_transfer_notify_email_task, send_owner_transfer_confirm_task, ) -from tasks.mail_reset_password_task import send_reset_password_mail_task +from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist +from tasks.mail_reset_password_task import ( + send_reset_password_mail_task, + send_reset_password_mail_task_when_account_not_exist, +) logger = logging.getLogger(__name__) @@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) + email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( - prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 + prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1 ) email_code_account_deletion_rate_limiter = RateLimiter( prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 @@ -95,6 +99,7 @@ class AccountService: FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 CHANGE_EMAIL_MAX_ERROR_LIMITS = 5 OWNER_TRANSFER_MAX_ERROR_LIMITS = 5 + EMAIL_REGISTER_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -166,12 +171,12 @@ class AccountService: return token @staticmethod - def authenticate(email: str, 
password: str, invite_token: Optional[str] = None) -> Account: + def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: """authenticate account with email and password""" account = db.session.query(Account).filter_by(email=email).first() if not account: - raise AccountNotFoundError() + raise AccountPasswordError("Invalid email or password.") if account.status == AccountStatus.BANNED.value: raise AccountLoginError("Account is banned.") @@ -223,9 +228,9 @@ class AccountService: email: str, name: str, interface_language: str, - password: Optional[str] = None, + password: str | None = None, interface_theme: str = "light", - is_setup: Optional[bool] = False, + is_setup: bool | None = False, ) -> Account: """create account""" if not FeatureService.get_system_features().is_allow_register and not is_setup: @@ -246,6 +251,8 @@ class AccountService: account.name = name if password: + valid_password(password) + # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() @@ -269,7 +276,7 @@ class AccountService: @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: Optional[str] = None + email: str, name: str, interface_language: str, password: str | None = None ) -> Account: """create account""" account = AccountService.create_account( @@ -294,7 +301,9 @@ class AccountService: if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError - raise EmailCodeAccountDeletionRateLimitExceededError() + raise EmailCodeAccountDeletionRateLimitExceededError( + int(cls.email_code_account_deletion_rate_limiter.time_window / 60) + ) send_account_deletion_verification_code.delay(to=email, code=code) @@ -321,7 +330,7 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = ( + account_integrate: AccountIntegrate | None = ( db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() ) @@ -382,7 +391,7 @@ class AccountService: db.session.commit() @staticmethod - def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: + def login(account: Account, *, ip_address: str | None = None) -> TokenPair: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) @@ -430,9 +439,10 @@ class AccountService: @classmethod def send_reset_password_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", + is_allow_register: bool = False, ): account_email = account.email if account else email if account_email is None: @@ -441,26 +451,67 @@ class AccountService: if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError - raise PasswordResetRateLimitExceededError() + raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60)) code, token = cls.generate_reset_password_token(account_email, account) - send_reset_password_mail_task.delay( - language=language, - to=account_email, - code=code, - ) + if account: + send_reset_password_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + else: + send_reset_password_mail_task_when_account_not_exist.delay( + 
language=language, + to=account_email, + is_allow_register=is_allow_register, + ) cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def send_email_register_email( + cls, + account: Account | None = None, + email: str | None = None, + language: str = "en-US", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.email_register_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import EmailRegisterRateLimitExceededError + + raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60)) + + code, token = cls.generate_email_register_token(account_email) + + if account: + send_email_register_mail_task_when_account_exist.delay( + language=language, + to=account_email, + account_name=account.name, + ) + + else: + send_email_register_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + cls.email_register_rate_limiter.increment_rate_limit(account_email) + return token + @classmethod def send_change_email_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, - old_email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, + old_email: str | None = None, language: str = "en-US", - phase: Optional[str] = None, + phase: str | None = None, ): account_email = account.email if account else email if account_email is None: @@ -471,7 +522,7 @@ class AccountService: if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError - raise EmailChangeRateLimitExceededError() + raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) @@ -487,8 +538,8 @@ class AccountService: @classmethod def send_change_email_completed_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", ): account_email = account.email if account else email @@ -503,10 +554,10 @@ class AccountService: @classmethod def send_owner_transfer_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", ): account_email = account.email if account else email if account_email is None: @@ -515,7 +566,7 @@ class AccountService: if cls.owner_transfer_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import OwnerTransferRateLimitExceededError - raise OwnerTransferRateLimitExceededError() + raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60)) code, token = cls.generate_owner_transfer_token(account_email, account) workspace_name = workspace_name or "" @@ -532,10 +583,10 @@ class AccountService: @classmethod def send_old_owner_transfer_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", new_owner_email: str = "", ): account_email = account.email if account else email @@ -553,10 +604,10 @@ class AccountService: 
     @classmethod
     def send_new_owner_transfer_notify_email(
         cls,
-        account: Optional[Account] = None,
-        email: Optional[str] = None,
+        account: Account | None = None,
+        email: str | None = None,
         language: str = "en-US",
-        workspace_name: Optional[str] = "",
+        workspace_name: str | None = "",
     ):
         account_email = account.email if account else email
         if account_email is None:
@@ -573,8 +624,8 @@ class AccountService:
     def generate_reset_password_token(
         cls,
         email: str,
-        account: Optional[Account] = None,
-        code: Optional[str] = None,
+        account: Account | None = None,
+        code: str | None = None,
         additional_data: dict[str, Any] = {},
     ):
         if not code:
@@ -585,13 +636,26 @@ class AccountService:
         )
         return code, token

+    @classmethod
+    def generate_email_register_token(
+        cls,
+        email: str,
+        code: str | None = None,
+        additional_data: dict[str, Any] = {},
+    ):
+        if not code:
+            code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
+        additional_data["code"] = code
+        token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
+        return code, token
+
     @classmethod
     def generate_change_email_token(
         cls,
         email: str,
-        account: Optional[Account] = None,
-        code: Optional[str] = None,
-        old_email: Optional[str] = None,
+        account: Account | None = None,
+        code: str | None = None,
+        old_email: str | None = None,
         additional_data: dict[str, Any] = {},
     ):
         if not code:
@@ -607,8 +671,8 @@ class AccountService:
     def generate_owner_transfer_token(
         cls,
         email: str,
-        account: Optional[Account] = None,
-        code: Optional[str] = None,
+        account: Account | None = None,
+        code: str | None = None,
         additional_data: dict[str, Any] = {},
     ):
         if not code:
@@ -623,6 +687,10 @@ class AccountService:
     def revoke_reset_password_token(cls, token: str):
         TokenManager.revoke_token(token, "reset_password")

+    @classmethod
+    def revoke_email_register_token(cls, token: str):
+        TokenManager.revoke_token(token, "email_register")
+
     @classmethod
     def revoke_change_email_token(cls, token: str):
         TokenManager.revoke_token(token, "change_email")
@@ -632,22 +700,26 @@ class AccountService:
         TokenManager.revoke_token(token, "owner_transfer")

     @classmethod
-    def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
+    def get_reset_password_data(cls, token: str) -> dict[str, Any] | None:
         return TokenManager.get_token_data(token, "reset_password")

     @classmethod
-    def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
+    def get_email_register_data(cls, token: str) -> dict[str, Any] | None:
+        return TokenManager.get_token_data(token, "email_register")
+
+    @classmethod
+    def get_change_email_data(cls, token: str) -> dict[str, Any] | None:
         return TokenManager.get_token_data(token, "change_email")

     @classmethod
-    def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
+    def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None:
         return TokenManager.get_token_data(token, "owner_transfer")

     @classmethod
     def send_email_code_login_email(
         cls,
-        account: Optional[Account] = None,
-        email: Optional[str] = None,
+        account: Account | None = None,
+        email: str | None = None,
         language: str = "en-US",
     ):
         email = account.email if account else email
@@ -656,7 +728,7 @@ class AccountService:
         if cls.email_code_login_rate_limiter.is_rate_limited(email):
             from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError

-            raise EmailCodeLoginRateLimitExceededError()
+            raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))

         code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
         token = TokenManager.generate_token(
@@ -671,7 +743,7 @@ class AccountService:
         return token

     @classmethod
-    def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
+    def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
         return TokenManager.get_token_data(token, "email_code_login")

     @classmethod
@@ -742,6 +814,16 @@ class AccountService:
         count = int(count) + 1
         redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)

+    @staticmethod
+    @redis_fallback(default_return=None)
+    def add_email_register_error_rate_limit(email: str) -> None:
+        key = f"email_register_error_rate_limit:{email}"
+        count = redis_client.get(key)
+        if count is None:
+            count = 0
+        count = int(count) + 1
+        redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
+
     @staticmethod
     @redis_fallback(default_return=False)
     def is_forgot_password_error_rate_limit(email: str) -> bool:
@@ -761,6 +843,24 @@ class AccountService:
         key = f"forgot_password_error_rate_limit:{email}"
         redis_client.delete(key)

+    @staticmethod
+    @redis_fallback(default_return=False)
+    def is_email_register_error_rate_limit(email: str) -> bool:
+        key = f"email_register_error_rate_limit:{email}"
+        count = redis_client.get(key)
+        if count is None:
+            return False
+        count = int(count)
+        if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
+            return True
+        return False
+
+    @staticmethod
+    @redis_fallback(default_return=None)
+    def reset_email_register_error_rate_limit(email: str):
+        key = f"email_register_error_rate_limit:{email}"
+        redis_client.delete(key)
+
     @staticmethod
     @redis_fallback(default_return=None)
     def add_change_email_error_rate_limit(email: str):
@@ -865,7 +965,7 @@ class AccountService:

 class TenantService:
     @staticmethod
-    def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
+    def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
         """Create tenant"""
         if (
             not FeatureService.get_system_features().is_allow_create_workspace
@@ -896,9 +996,7 @@ class TenantService:
         return tenant

     @staticmethod
-    def create_owner_tenant_if_not_exist(
-        account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
-    ):
+    def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
         """Check if user have a workspace or not"""
         available_ta = (
             db.session.query(TenantAccountJoin)
@@ -970,7 +1068,7 @@ class TenantService:
         return tenant

     @staticmethod
-    def switch_tenant(account: Account, tenant_id: Optional[str] = None):
+    def switch_tenant(account: Account, tenant_id: str | None = None):
         """Switch the current workspace for the account"""

         # Ensure tenant_id is provided
@@ -1052,7 +1150,7 @@ class TenantService:
         )

     @staticmethod
-    def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
+    def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
         """Get the role of the current account for a given tenant"""
         join = (
             db.session.query(TenantAccountJoin)
@@ -1192,13 +1290,13 @@ class RegisterService:
         cls,
         email,
         name,
-        password: Optional[str] = None,
-        open_id: Optional[str] = None,
-        provider: Optional[str] = None,
-        language: Optional[str] = None,
-        status: Optional[AccountStatus] = None,
-        is_setup: Optional[bool] = False,
-        create_workspace_required: Optional[bool] = True,
+        password: str | None = None,
+        open_id: str | None = None,
+        provider: str | None = None,
+        language: str | None = None,
+        status: AccountStatus | None = None,
+        is_setup: bool | None = False,
+        create_workspace_required: bool | None = True,
     ) -> Account:
         db.session.begin_nested()
         """Register account"""
@@ -1315,10 +1413,8 @@ class RegisterService:
         redis_client.delete(cls._get_invitation_token_key(token))

     @classmethod
-    def get_invitation_if_token_valid(
-        cls, workspace_id: Optional[str], email: str, token: str
-    ) -> Optional[dict[str, Any]]:
-        invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
+    def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
+        invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
         if not invitation_data:
             return None

@@ -1355,9 +1451,9 @@ class RegisterService:
         }

     @classmethod
-    def _get_invitation_by_token(
-        cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
-    ) -> Optional[dict[str, str]]:
+    def get_invitation_by_token(
+        cls, token: str, workspace_id: str | None = None, email: str | None = None
+    ) -> dict[str, str] | None:
         if workspace_id is not None and email is not None:
             email_hash = sha256(email.encode()).hexdigest()
             cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py
index 6f0ab2546a..f2ffa3b170 100644
--- a/api/services/advanced_prompt_template_service.py
+++ b/api/services/advanced_prompt_template_service.py
@@ -32,14 +32,14 @@ class AdvancedPromptTemplateService:
     def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
         context_prompt = copy.deepcopy(CONTEXT)

-        if app_mode == AppMode.CHAT.value:
+        if app_mode == AppMode.CHAT:
             if model_mode == "completion":
                 return cls.get_completion_prompt(
                     copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
                 )
             elif model_mode == "chat":
                 return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
-        elif app_mode == AppMode.COMPLETION.value:
+        elif app_mode == AppMode.COMPLETION:
             if model_mode == "completion":
                 return cls.get_completion_prompt(
                     copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
@@ -73,7 +73,7 @@ class AdvancedPromptTemplateService:
     def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
         baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

-        if app_mode == AppMode.CHAT.value:
+        if app_mode == AppMode.CHAT:
             if model_mode == "completion":
                 return cls.get_completion_prompt(
                     copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
@@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
                 return cls.get_chat_prompt(
                     copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
                 )
-        elif app_mode == AppMode.COMPLETION.value:
+        elif app_mode == AppMode.COMPLETION:
             if model_mode == "completion":
                 return cls.get_completion_prompt(
                     copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
diff --git a/api/services/agent_service.py b/api/services/agent_service.py
index 8578f38a0d..d631ce812f 100644
--- a/api/services/agent_service.py
+++ b/api/services/agent_service.py
@@ -1,5 +1,5 @@
 import threading
-from typing import Any, Optional
+from typing import Any

 import pytz
@@ -35,7 +35,7 @@ class AgentService:
         if not conversation:
             raise ValueError(f"Conversation not found: {conversation_id}")

-        message: Optional[Message] = (
+        message: Message | None = (
             db.session.query(Message)
             .where(
                 Message.id == message_id,
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index ba86a31240..9feca7337f 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,5 +1,4 @@
 import uuid
-from typing import Optional

 import pandas as pd
 from sqlalchemy import or_, select
@@ -42,7 +41,7 @@ class AppAnnotationService:
         if not message:
             raise NotFound("Message Not Exists.")

-        annotation: Optional[MessageAnnotation] = message.annotation
+        annotation: MessageAnnotation | None = message.annotation
         # save the message annotation
         if annotation:
             annotation.content = args["answer"]
@@ -263,11 +262,9 @@ class AppAnnotationService:

         db.session.delete(annotation)

-        annotation_hit_histories = (
-            db.session.query(AppAnnotationHitHistory)
-            .where(AppAnnotationHitHistory.annotation_id == annotation_id)
-            .all()
-        )
+        annotation_hit_histories = db.session.scalars(
+            select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
+        ).all()
         if annotation_hit_histories:
             for annotation_hit_history in annotation_hit_histories:
                 db.session.delete(annotation_hit_history)
@@ -349,7 +346,7 @@ class AppAnnotationService:
         try:
             # Skip the first row
-            df = pd.read_csv(file, dtype=str)
+            df = pd.read_csv(file.stream, dtype=str)
             result = []
             for _, row in df.iterrows():
                 content = {"question": row.iloc[0], "answer": row.iloc[1]}
@@ -463,15 +460,23 @@ class AppAnnotationService:
         annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
         if annotation_setting:
             collection_binding_detail = annotation_setting.collection_binding_detail
-            return {
-                "id": annotation_setting.id,
-                "enabled": True,
-                "score_threshold": annotation_setting.score_threshold,
-                "embedding_model": {
-                    "embedding_provider_name": collection_binding_detail.provider_name,
-                    "embedding_model_name": collection_binding_detail.model_name,
-                },
-            }
+            if collection_binding_detail:
+                return {
+                    "id": annotation_setting.id,
+                    "enabled": True,
+                    "score_threshold": annotation_setting.score_threshold,
+                    "embedding_model": {
+                        "embedding_provider_name": collection_binding_detail.provider_name,
+                        "embedding_model_name": collection_binding_detail.model_name,
+                    },
+                }
+            else:
+                return {
+                    "id": annotation_setting.id,
+                    "enabled": True,
+                    "score_threshold": annotation_setting.score_threshold,
+                    "embedding_model": {},
+                }
         return {"enabled": False}

     @classmethod
@@ -506,15 +511,23 @@ class AppAnnotationService:

         collection_binding_detail = annotation_setting.collection_binding_detail

-        return {
-            "id": annotation_setting.id,
-            "enabled": True,
-            "score_threshold": annotation_setting.score_threshold,
-            "embedding_model": {
-                "embedding_provider_name": collection_binding_detail.provider_name,
-                "embedding_model_name": collection_binding_detail.model_name,
-            },
-        }
+        if collection_binding_detail:
+            return {
+                "id": annotation_setting.id,
+                "enabled": True,
+                "score_threshold": annotation_setting.score_threshold,
+                "embedding_model": {
+                    "embedding_provider_name": collection_binding_detail.provider_name,
+                    "embedding_model_name": collection_binding_detail.model_name,
+                },
+            }
+        else:
+            return {
+                "id": annotation_setting.id,
+                "enabled": True,
+                "score_threshold": annotation_setting.score_threshold,
+                "embedding_model": {},
+            }

     @classmethod
     def clear_all_annotations(cls, app_id: str):
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index aaf7e3ab5a..8701fe4f4e 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -4,7 +4,6 @@ import logging
 import uuid
 from collections.abc import Mapping
 from enum import StrEnum
-from typing import Optional
 from urllib.parse import urlparse
 from uuid import uuid4

@@ -61,8 +60,8 @@ class ImportStatus(StrEnum):
 class Import(BaseModel):
     id: str
     status: ImportStatus
-    app_id: Optional[str] = None
-    app_mode: Optional[str] = None
+    app_id: str | None = None
+    app_mode: str | None = None
     current_dsl_version: str = CURRENT_DSL_VERSION
     imported_dsl_version: str = ""
     error: str = ""
@@ -99,17 +98,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
 class PendingData(BaseModel):
     import_mode: str
     yaml_content: str
-    name: str | None
-    description: str | None
-    icon_type: str | None
-    icon: str | None
-    icon_background: str | None
-    app_id: str | None
+    name: str | None = None
+    description: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    app_id: str | None = None


 class CheckDependenciesPendingData(BaseModel):
     dependencies: list[PluginDependency]
-    app_id: str | None
+    app_id: str | None = None


 class AppDslService:
@@ -121,14 +120,14 @@ class AppDslService:
         *,
         account: Account,
         import_mode: str,
-        yaml_content: Optional[str] = None,
-        yaml_url: Optional[str] = None,
-        name: Optional[str] = None,
-        description: Optional[str] = None,
-        icon_type: Optional[str] = None,
-        icon: Optional[str] = None,
-        icon_background: Optional[str] = None,
-        app_id: Optional[str] = None,
+        yaml_content: str | None = None,
+        yaml_url: str | None = None,
+        name: str | None = None,
+        description: str | None = None,
+        icon_type: str | None = None,
+        icon: str | None = None,
+        icon_background: str | None = None,
+        app_id: str | None = None,
     ) -> Import:
         """Import an app from YAML content or URL."""
         import_id = str(uuid.uuid4())
@@ -407,15 +406,15 @@ class AppDslService:
     def _create_or_update_app(
         self,
         *,
-        app: Optional[App],
+        app: App | None,
         data: dict,
         account: Account,
-        name: Optional[str] = None,
-        description: Optional[str] = None,
-        icon_type: Optional[str] = None,
-        icon: Optional[str] = None,
-        icon_background: Optional[str] = None,
-        dependencies: Optional[list[PluginDependency]] = None,
+        name: str | None = None,
+        description: str | None = None,
+        icon_type: str | None = None,
+        icon: str | None = None,
+        icon_background: str | None = None,
+        dependencies: list[PluginDependency] | None = None,
     ) -> App:
         """Create a new app or update an existing one."""
         app_data = data.get("app", {})
@@ -533,7 +532,7 @@ class AppDslService:
         return app

     @classmethod
-    def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: Optional[str] = None) -> str:
+    def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
         """
         Export app
         :param app_model: App instance
@@ -566,7 +565,7 @@ class AppDslService:

     @classmethod
     def _append_workflow_export_data(
-        cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
+        cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
     ):
         """
         Append workflow export data
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 4fa91ef682..8911da4728 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -1,6 +1,6 @@
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Optional, Union
+from typing import Any, Union

 from openai._exceptions import RateLimitError

@@ -60,7 +60,7 @@ class AppGenerateService:
         request_id = RateLimit.gen_request_key()
         try:
             request_id = rate_limit.enter(request_id)
-            if app_model.mode == AppMode.COMPLETION.value:
+            if app_model.mode == AppMode.COMPLETION:
                 return rate_limit.generate(
                     CompletionAppGenerator.convert_to_event_stream(
                         CompletionAppGenerator().generate(
@@ -69,7 +69,7 @@ class AppGenerateService:
                     ),
                     request_id=request_id,
                 )
-            elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
+            elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
                 return rate_limit.generate(
                     AgentChatAppGenerator.convert_to_event_stream(
                         AgentChatAppGenerator().generate(
@@ -78,7 +78,7 @@ class AppGenerateService:
                     ),
                     request_id,
                 )
-            elif app_model.mode == AppMode.CHAT.value:
+            elif app_model.mode == AppMode.CHAT:
                 return rate_limit.generate(
                     ChatAppGenerator.convert_to_event_stream(
                         ChatAppGenerator().generate(
@@ -87,7 +87,7 @@ class AppGenerateService:
                     ),
                     request_id=request_id,
                 )
-            elif app_model.mode == AppMode.ADVANCED_CHAT.value:
+            elif app_model.mode == AppMode.ADVANCED_CHAT:
                 workflow_id = args.get("workflow_id")
                 workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
                 return rate_limit.generate(
@@ -103,7 +103,7 @@ class AppGenerateService:
                     ),
                     request_id=request_id,
                 )
-            elif app_model.mode == AppMode.WORKFLOW.value:
+            elif app_model.mode == AppMode.WORKFLOW:
                 workflow_id = args.get("workflow_id")
                 workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
                 return rate_limit.generate(
@@ -154,14 +154,14 @@ class AppGenerateService:

     @classmethod
     def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
-        if app_model.mode == AppMode.ADVANCED_CHAT.value:
+        if app_model.mode == AppMode.ADVANCED_CHAT:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator.convert_to_event_stream(
                 AdvancedChatAppGenerator().single_iteration_generate(
                     app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                 )
             )
-        elif app_model.mode == AppMode.WORKFLOW.value:
+        elif app_model.mode == AppMode.WORKFLOW:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator.convert_to_event_stream(
                 WorkflowAppGenerator().single_iteration_generate(
@@ -173,14 +173,14 @@ class AppGenerateService:

     @classmethod
     def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
-        if app_model.mode == AppMode.ADVANCED_CHAT.value:
+        if app_model.mode == AppMode.ADVANCED_CHAT:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator.convert_to_event_stream(
                 AdvancedChatAppGenerator().single_loop_generate(
                     app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                 )
             )
-        elif app_model.mode == AppMode.WORKFLOW.value:
+        elif app_model.mode == AppMode.WORKFLOW:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator.convert_to_event_stream(
                 WorkflowAppGenerator().single_loop_generate(
@@ -213,7 +213,7 @@ class AppGenerateService:
         )

     @classmethod
-    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
+    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
         """
         Get workflow
         :param app_model: app model
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 9b200a570d..ab2b38ec01 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Optional, TypedDict, cast
+from typing import TypedDict, cast

 from flask_sqlalchemy.pagination import Pagination

@@ -40,15 +40,15 @@ class AppService:
         filters = [App.tenant_id == tenant_id, App.is_universal == False]

         if args["mode"] == "workflow":
-            filters.append(App.mode == AppMode.WORKFLOW.value)
+            filters.append(App.mode == AppMode.WORKFLOW)
         elif args["mode"] == "completion":
-            filters.append(App.mode == AppMode.COMPLETION.value)
+            filters.append(App.mode == AppMode.COMPLETION)
         elif args["mode"] == "chat":
-            filters.append(App.mode == AppMode.CHAT.value)
+            filters.append(App.mode == AppMode.CHAT)
         elif args["mode"] == "advanced-chat":
-            filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
+            filters.append(App.mode == AppMode.ADVANCED_CHAT)
         elif args["mode"] == "agent-chat":
-            filters.append(App.mode == AppMode.AGENT_CHAT.value)
+            filters.append(App.mode == AppMode.AGENT_CHAT)

         if args.get("is_created_by_me", False):
             filters.append(App.created_by == user_id)
@@ -171,7 +171,7 @@ class AppService:
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
         # get original app model config
-        if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
+        if app.mode == AppMode.AGENT_CHAT or app.is_agent:
             model_config = app.app_model_config
             if not model_config:
                 return app
@@ -370,7 +370,7 @@ class AppService:
                 }
             )
         else:
-            app_model_config: Optional[AppModelConfig] = app_model.app_model_config
+            app_model_config: AppModelConfig | None = app_model.app_model_config
             if not app_model_config:
                 return meta

@@ -393,7 +393,7 @@ class AppService:
                         meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
                     elif provider_type == "api":
                         try:
-                            provider: Optional[ApiToolProvider] = (
+                            provider: ApiToolProvider | None = (
                                 db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
                             )
                             if provider is None:
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index 9b1999d813..1158fc5197 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -2,7 +2,6 @@ import io
 import logging
 import uuid
 from collections.abc import Generator
-from typing import Optional

 from flask import Response, stream_with_context
 from werkzeug.datastructures import FileStorage
@@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)

 class AudioService:
     @classmethod
-    def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
-        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+    def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
+        if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
             workflow = app_model.workflow
             if workflow is None:
                 raise ValueError("Speech to text is not enabled")
@@ -77,18 +76,18 @@ class AudioService:
     def transcript_tts(
         cls,
         app_model: App,
-        text: Optional[str] = None,
-        voice: Optional[str] = None,
-        end_user: Optional[str] = None,
-        message_id: Optional[str] = None,
+        text: str | None = None,
+        voice: str | None = None,
+        end_user: str | None = None,
+        message_id: str | None = None,
         is_draft: bool = False,
     ):
         from app import app

-        def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
+        def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
             with app.app_context():
                 if voice is None:
-                    if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+                    if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
                         if is_draft:
                             workflow = WorkflowService().get_draft_workflow(app_model=app_model)
                         else:
diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py
index f6e960b413..055cf65816 100644
--- a/api/services/auth/api_key_auth_service.py
+++ b/api/services/auth/api_key_auth_service.py
@@ -1,5 +1,7 @@
 import json

+from sqlalchemy import select
+
 from core.helper import encrypter
 from extensions.ext_database import db
 from models.source import DataSourceApiKeyAuthBinding
@@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
 class ApiKeyAuthService:
     @staticmethod
     def get_provider_auth_list(tenant_id: str):
-        data_source_api_key_bindings = (
-            db.session.query(DataSourceApiKeyAuthBinding)
-            .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
-            .all()
-        )
+        data_source_api_key_bindings = db.session.scalars(
+            select(DataSourceApiKeyAuthBinding).where(
+                DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
+            )
+        ).all()
         return data_source_api_key_bindings

     @staticmethod
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index 066bed3234..a364862a88 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -1,5 +1,5 @@
 import os
-from typing import Literal, Optional
+from typing import Literal

 import httpx
 from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@@ -73,7 +73,7 @@ class BillingService:
     def is_tenant_owner_or_admin(current_user: Account):
         tenant_id = current_user.current_tenant_id

-        join: Optional[TenantAccountJoin] = (
+        join: TenantAccountJoin | None = (
             db.session.query(TenantAccountJoin)
             .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
             .first()
diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py
index 2f1b63664f..f8f89d7428 100644
--- a/api/services/clear_free_plan_tenant_expired_logs.py
+++ b/api/services/clear_free_plan_tenant_expired_logs.py
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor

 import click
 from flask import Flask, current_app
+from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

 from configs import dify_config
@@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
     @classmethod
     def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
         with flask_app.app_context():
-            apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
+            apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
             app_ids = [app.id for app in apps]
             while True:
                 with Session(db.engine).no_autoflush as session:
@@ -407,6 +408,7 @@ class ClearFreePlanTenantExpiredLogs:
                     datetime.timedelta(hours=1),
                 ]

+                tenant_count = 0
                 for test_interval in test_intervals:
                     tenant_count = (
                         session.query(Tenant.id)
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index d017ce54ab..a8e51a426d 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,7 +1,7 @@
 import contextlib
 import logging
 from collections.abc import Callable, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Union

 from sqlalchemy import asc, desc, func, or_, select
 from sqlalchemy.orm import Session
@@ -36,12 +36,12 @@ class ConversationService:
         *,
         session: Session,
         app_model: App,
-        user: Optional[Union[Account, EndUser]],
-        last_id: Optional[str],
+        user: Union[Account, EndUser] | None,
+        last_id: str | None,
         limit: int,
         invoke_from: InvokeFrom,
-        include_ids: Optional[Sequence[str]] = None,
-        exclude_ids: Optional[Sequence[str]] = None,
+        include_ids: Sequence[str] | None = None,
+        exclude_ids: Sequence[str] | None = None,
         sort_by: str = "-updated_at",
     ) -> InfiniteScrollPagination:
         if not user:
@@ -118,7 +118,7 @@ class ConversationService:
         cls,
         app_model: App,
         conversation_id: str,
-        user: Optional[Union[Account, EndUser]],
+        user: Union[Account, EndUser] | None,
         name: str,
         auto_generate: bool,
     ):
@@ -158,7 +158,7 @@ class ConversationService:
         return conversation

     @classmethod
-    def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+    def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
         conversation = (
             db.session.query(Conversation)
             .where(
@@ -178,7 +178,7 @@ class ConversationService:
         return conversation

     @classmethod
-    def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+    def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
         try:
             logger.info(
                 "Initiating conversation deletion for app_name %s, conversation_id: %s",
@@ -200,9 +200,9 @@ class ConversationService:
         cls,
         app_model: App,
         conversation_id: str,
-        user: Optional[Union[Account, EndUser]],
+        user: Union[Account, EndUser] | None,
         limit: int,
-        last_id: Optional[str],
+        last_id: str | None,
     ) -> InfiniteScrollPagination:
         conversation = cls.get_conversation(app_model, conversation_id, user)

@@ -222,8 +222,8 @@ class ConversationService:
                 # Filter for variables created after the last_id
                 stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)

-        # Apply limit to query
-        query_stmt = stmt.limit(limit)  # Get one extra to check if there are more
+        # Apply limit to query: fetch one extra row to determine has_more
+        query_stmt = stmt.limit(limit + 1)
         rows = session.scalars(query_stmt).all()

         has_more = False
@@ -248,7 +248,7 @@ class ConversationService:
         app_model: App,
         conversation_id: str,
         variable_id: str,
-        user: Optional[Union[Account, EndUser]],
+        user: Union[Account, EndUser] | None,
         new_value: Any,
     ):
         """
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 03757fe4a5..d0d3c2d426 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -6,7 +6,8 @@ import secrets
 import time
 import uuid
 from collections import Counter
-from typing import Any, Literal, Optional
+from collections.abc import Sequence
+from typing import Any, Literal

 import sqlalchemy as sa
 import yaml
@@ -47,6 +48,7 @@ from models.dataset import (
 )
 from models.model import UploadFile
 from models.provider_ids import ModelProviderID
+from models.source import DataSourceOauthBinding
 from models.workflow import Workflow
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
@@ -144,11 +146,14 @@ class DatasetService:

         # Check if tag_ids is not empty to avoid WHERE false condition
         if tag_ids and len(tag_ids) > 0:
-            target_ids = TagService.get_target_ids_by_tag_ids(
-                "knowledge",
-                tenant_id,  # ty: ignore [invalid-argument-type]
-                tag_ids,
-            )
+            if tenant_id is not None:
+                target_ids = TagService.get_target_ids_by_tag_ids(
+                    "knowledge",
+                    tenant_id,
+                    tag_ids,
+                )
+            else:
+                target_ids = []
             if target_ids and len(target_ids) > 0:
                 query = query.where(Dataset.id.in_(target_ids))
             else:
@@ -191,16 +196,16 @@ class DatasetService:
     def create_empty_dataset(
         tenant_id: str,
         name: str,
-        description: Optional[str],
-        indexing_technique: Optional[str],
+        description: str | None,
+        indexing_technique: str | None,
         account: Account,
-        permission: Optional[str] = None,
+        permission: str | None = None,
         provider: str = "vendor",
-        external_knowledge_api_id: Optional[str] = None,
-        external_knowledge_id: Optional[str] = None,
-        embedding_model_provider: Optional[str] = None,
-        embedding_model_name: Optional[str] = None,
-        retrieval_model: Optional[RetrievalModel] = None,
+        external_knowledge_api_id: str | None = None,
+        external_knowledge_id: str | None = None,
+        embedding_model_provider: str | None = None,
+        embedding_model_name: str | None = None,
+        retrieval_model: RetrievalModel | None = None,
     ):
         # check if dataset name already exists
         if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@@ -969,7 +974,7 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
-    def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
+    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
         if not dataset:
             raise ValueError("Dataset not found")
@@ -1028,14 +1033,12 @@ class DatasetService:
         }
         # get recent 30 days auto disable logs
         start_date = datetime.datetime.now() - datetime.timedelta(days=30)
-        dataset_auto_disable_logs = (
-            db.session.query(DatasetAutoDisableLog)
-            .where(
+        dataset_auto_disable_logs = db.session.scalars(
+            select(DatasetAutoDisableLog).where(
                 DatasetAutoDisableLog.dataset_id == dataset_id,
                 DatasetAutoDisableLog.created_at >= start_date,
             )
-            .all()
-        )
+        ).all()
         if dataset_auto_disable_logs:
             return {
                 "document_ids": [log.document_id for log in dataset_auto_disable_logs],
@@ -1156,7 +1159,7 @@ class DocumentService:
     }

     @staticmethod
-    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+    def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
         if document_id:
             document = (
                 db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
@@ -1166,75 +1169,64 @@ class DocumentService:
             return None

     @staticmethod
-    def get_document_by_id(document_id: str) -> Optional[Document]:
+    def get_document_by_id(document_id: str) -> Document | None:
         document = db.session.query(Document).where(Document.id == document_id).first()

         return document

     @staticmethod
-    def get_document_by_ids(document_ids: list[str]) -> list[Document]:
-        documents = (
-            db.session.query(Document)
-            .where(
+    def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
+        documents = db.session.scalars(
+            select(Document).where(
                 Document.id.in_(document_ids),
                 Document.enabled == True,
                 Document.indexing_status == "completed",
                 Document.archived == False,
             )
-            .all()
-        )
+        ).all()
         return documents

     @staticmethod
-    def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
-        documents = (
-            db.session.query(Document)
-            .where(
+    def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
+        documents = db.session.scalars(
+            select(Document).where(
                 Document.dataset_id == dataset_id,
                 Document.enabled == True,
             )
-            .all()
-        )
+        ).all()
         return documents

     @staticmethod
-    def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
-        documents = (
-            db.session.query(Document)
-            .where(
+    def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
+        documents = db.session.scalars(
+            select(Document).where(
                 Document.dataset_id == dataset_id,
                 Document.enabled == True,
                 Document.indexing_status == "completed",
                 Document.archived == False,
             )
-            .all()
-        )
+        ).all()
         return documents

     @staticmethod
-    def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
-        documents = (
-            db.session.query(Document)
-            .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
-            .all()
-        )
+    def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
+        documents = db.session.scalars(
+            select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
+        ).all()
         return documents

     @staticmethod
-    def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
+    def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
         assert isinstance(current_user, Account)
-
-        documents = (
-            db.session.query(Document)
-            .where(
+        documents = db.session.scalars(
+            select(Document).where(
                 Document.batch == batch,
                 Document.dataset_id == dataset_id,
                 Document.tenant_id == current_user.current_tenant_id,
             )
-            .all()
-        )
+        ).all()

         return documents
@@ -1271,13 +1263,14 @@ class DocumentService:
         # Check if document_ids is not empty to avoid WHERE false condition
         if not document_ids or len(document_ids) == 0:
             return
-        documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
+        documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
         file_ids = [
             document.data_source_info_dict.get("upload_file_id", "")
             for document in documents
             if document.data_source_type == "upload_file" and document.data_source_info_dict
         ]
-        batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
+        if dataset.doc_form is not None:
+            batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)

         for document in documents:
             db.session.delete(document)
@@ -1302,7 +1295,7 @@ class DocumentService:
         if dataset.built_in_field_enabled:
             if document.doc_metadata:
                 doc_metadata = copy.deepcopy(document.doc_metadata)
-                doc_metadata[BuiltInField.document_name.value] = name
+                doc_metadata[BuiltInField.document_name] = name
                 document.doc_metadata = doc_metadata
         document.name = name

@@ -1397,7 +1390,7 @@ class DocumentService:
         dataset: Dataset,
         knowledge_config: KnowledgeConfig,
         account: Account | Any,
-        dataset_process_rule: Optional[DatasetProcessRule] = None,
+        dataset_process_rule: DatasetProcessRule | None = None,
         created_from: str = "web",
     ) -> tuple[list[Document], str]:
         # check doc_form
@@ -2025,7 +2018,7 @@ class DocumentService:
         dataset: Dataset,
         document_data: KnowledgeConfig,
         account: Account,
-        dataset_process_rule: Optional[DatasetProcessRule] = None,
+        dataset_process_rule: DatasetProcessRule | None = None,
         created_from: str = "web",
     ):
         assert isinstance(current_user, Account)
@@ -2930,7 +2923,22 @@ class SegmentService:
         if segment.enabled:
             # send delete segment index task
             redis_client.setex(indexing_cache_key, 600, 1)
-            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
+
+            # Get child chunk IDs before parent segment is deleted
+            child_node_ids = []
+            if segment.index_node_id:
+                child_chunks = (
+                    db.session.query(ChildChunk.index_node_id)
+                    .where(
+                        ChildChunk.segment_id == segment.id,
+                        ChildChunk.dataset_id == dataset.id,
+                    )
+                    .all()
+                )
+                child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
+
+            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
+
         db.session.delete(segment)
         # update document word count
         assert document.word_count is not None
@@ -2940,9 +2948,13 @@ class SegmentService:

     @classmethod
     def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
-        assert isinstance(current_user, Account)
-        segments = (
-            db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
+        assert current_user is not None
+        # Check if segment_ids is not empty to avoid WHERE false condition
+        if not segment_ids or len(segment_ids) == 0:
+            return
+        segments_info = (
+            db.session.query(DocumentSegment)
+            .with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
             .where(
                 DocumentSegment.id.in_(segment_ids),
                 DocumentSegment.dataset_id == dataset.id,
@@ -2952,11 +2964,29 @@ class SegmentService:
             .all()
         )

-        if not segments:
+        if not segments_info:
             return

-        index_node_ids = [seg.index_node_id for seg in segments]
-        total_words = sum(seg.word_count for seg in segments)
+        index_node_ids = [info[0] for info in segments_info]
+        segment_db_ids = [info[1] for info in segments_info]
+        total_words = sum(info[2] for info in segments_info if info[2] is not None)
+
+        # Get child chunk IDs before parent segments are deleted
+        child_node_ids = []
+        if index_node_ids:
+            child_chunks = (
+                db.session.query(ChildChunk.index_node_id)
+                .where(
+                    ChildChunk.segment_id.in_(segment_db_ids),
+                    ChildChunk.dataset_id == dataset.id,
+                )
+                .all()
+            )
+            child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
+
+        # Start async cleanup with both parent and child node IDs
+        if index_node_ids or child_node_ids:
+            delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)

         if document.word_count is None:
             document.word_count = 0
@@ -2965,7 +2995,7 @@ class SegmentService:

         db.session.add(document)

-        delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
+        # Delete database records
         db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
         db.session.commit()

@@ -2979,16 +3009,14 @@ class SegmentService:
         if not segment_ids or len(segment_ids) == 0:
             return
         if action == "enable":
-            segments = (
-                db.session.query(DocumentSegment)
-                .where(
+            segments = db.session.scalars(
+                select(DocumentSegment).where(
                     DocumentSegment.id.in_(segment_ids),
                     DocumentSegment.dataset_id == dataset.id,
                     DocumentSegment.document_id == document.id,
                     DocumentSegment.enabled == False,
                 )
-                .all()
-            )
+            ).all()
             if not segments:
                 return
             real_deal_segment_ids = []
@@ -3006,16 +3034,14 @@ class SegmentService:
             enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
         elif action == "disable":
-            segments = (
-                db.session.query(DocumentSegment)
-                .where(
+            segments = db.session.scalars(
+                select(DocumentSegment).where(
                     DocumentSegment.id.in_(segment_ids),
                     DocumentSegment.dataset_id == dataset.id,
                     DocumentSegment.document_id == document.id,
                     DocumentSegment.enabled == True,
                 )
-                .all()
-            )
+            ).all()
             if not segments:
                 return
             real_deal_segment_ids = []
@@ -3087,16 +3113,13 @@ class SegmentService:
         dataset: Dataset,
     ) -> list[ChildChunk]:
         assert isinstance(current_user, Account)
-
-        child_chunks = (
-            db.session.query(ChildChunk)
-            .where(
+        child_chunks = db.session.scalars(
+            select(ChildChunk).where(
                 ChildChunk.dataset_id == dataset.id,
                 ChildChunk.document_id == document.id,
                 ChildChunk.segment_id == segment.id,
             )
-            .all()
-        )
+        ).all()
         child_chunks_map = {chunk.id: chunk for chunk in child_chunks}

         new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
@@ -3192,7 +3215,7 @@ class SegmentService:

     @classmethod
     def get_child_chunks(
-        cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
+        cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None
     ):
         assert isinstance(current_user, Account)

@@ -3211,7 +3234,7 @@ class SegmentService:
         return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

     @classmethod
-    def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
+    def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
         """Get a child chunk by its ID."""
         result = (
             db.session.query(ChildChunk)
@@ -3248,57 +3271,7 @@ class SegmentService:
         return paginated_segments.items, paginated_segments.total

     @classmethod
-    def update_segment_by_id(
-        cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
-    ) -> tuple[DocumentSegment, Document]:
-        """Update a segment by its ID with validation and checks."""
-        # check dataset
-        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
-        if not dataset:
-            raise NotFound("Dataset not found.")
-
-        # check user's model setting
-        DatasetService.check_dataset_model_setting(dataset)
-
-        # check document
-        document = DocumentService.get_document(dataset_id, document_id)
-        if not document:
-            raise NotFound("Document not found.")
-
-        # check embedding model setting if high quality
-        if dataset.indexing_technique == "high_quality":
-            try:
-                model_manager = ModelManager()
-                model_manager.get_model_instance(
-                    tenant_id=user_id,
-                    provider=dataset.embedding_model_provider,
-                    model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model,
-                )
-            except LLMBadRequestError:
-                raise ValueError(
-                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
- ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - - # check segment - segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) - .first() - ) - if not segment: - raise NotFound("Segment not found.") - - # validate and update segment - cls.segment_create_args_validate(segment_data, document) - updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset) - - return updated_segment, document - - @classmethod - def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: + def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: """Get a segment by its ID.""" result = ( db.session.query(DocumentSegment) @@ -3356,19 +3329,13 @@ class DatasetCollectionBindingService: class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = ( - db.session.query( + user_list_query = db.session.scalars( + select( DatasetPermission.account_id, - ) - .where(DatasetPermission.dataset_id == dataset_id) - .all() - ) + ).where(DatasetPermission.dataset_id == dataset_id) + ).all() - user_list = [] - for user in user_list_query: - user_list.append(user.account_id) - - return user_list + return user_list_query @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index 7f13fc3abb..817dbd95f8 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -6,10 +6,12 @@ from pydantic import BaseModel from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError +logger = logging.getLogger(__name__) + class PluginCredentialType(enum.Enum): - MODEL = 0 - TOOL = 1 + MODEL = 0 # must be 0 for API contract compatibility + TOOL = 1 # must be 1 for API contract compatibility def to_number(self): return self.value @@ -49,5 +51,7 @@ class PluginManagerService: logging.debug( "Credential policy compliance checked for %s with credential %s, result: %s", - body.provider, body.dify_credential_id, ret.get('result', False) + body.provider, + body.dify_credential_id, + ret.get("result", False), ) diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index 4545f385eb..c9fb1c9e21 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal, Union from pydantic import BaseModel @@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel): class Authorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[AuthorizationConfig] = None + config: AuthorizationConfig | None = None class ProcessStatusSetting(BaseModel): @@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel): class ExternalKnowledgeApiSetting(BaseModel): url: str request_method: str - headers: Optional[dict] = None - params: Optional[dict] = None + headers: dict | None = None + params: dict | None = None diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py 
b/api/services/entities/knowledge_entities/knowledge_entities.py index 26678c2e69..33f65bde58 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -11,14 +11,14 @@ class ParentMode(StrEnum): class NotionIcon(BaseModel): type: str - url: Optional[str] = None - emoji: Optional[str] = None + url: str | None = None + emoji: str | None = None class NotionPage(BaseModel): page_id: str page_name: str - page_icon: Optional[NotionIcon] = None + page_icon: NotionIcon | None = None type: str @@ -41,9 +41,9 @@ class FileInfo(BaseModel): class InfoList(BaseModel): data_source_type: Literal["upload_file", "notion_import", "website_crawl"] - notion_info_list: Optional[list[NotionInfo]] = None - file_info_list: Optional[FileInfo] = None - website_info_list: Optional[WebsiteInfo] = None + notion_info_list: list[NotionInfo] | None = None + file_info_list: FileInfo | None = None + website_info_list: WebsiteInfo | None = None class DataSource(BaseModel): @@ -62,20 +62,20 @@ class Segmentation(BaseModel): class Rule(BaseModel): - pre_processing_rules: Optional[list[PreProcessingRule]] = None - segmentation: Optional[Segmentation] = None - parent_mode: Optional[Literal["full-doc", "paragraph"]] = None - subchunk_segmentation: Optional[Segmentation] = None + pre_processing_rules: list[PreProcessingRule] | None = None + segmentation: Segmentation | None = None + parent_mode: Literal["full-doc", "paragraph"] | None = None + subchunk_segmentation: Segmentation | None = None class ProcessRule(BaseModel): mode: Literal["automatic", "custom", "hierarchical"] - rules: Optional[Rule] = None + rules: Rule | None = None class RerankingModel(BaseModel): - reranking_provider_name: Optional[str] = None - reranking_model_name: Optional[str] = None + reranking_provider_name: str | None = None + reranking_model_name: str | None = None class WeightVectorSetting(BaseModel): @@ -89,20 +89,20 @@ class WeightKeywordSetting(BaseModel): class WeightModel(BaseModel): - weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None - vector_setting: Optional[WeightVectorSetting] = None - keyword_setting: Optional[WeightKeywordSetting] = None + weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None + vector_setting: WeightVectorSetting | None = None + keyword_setting: WeightKeywordSetting | None = None class RetrievalModel(BaseModel): search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"] reranking_enable: bool - reranking_model: Optional[RerankingModel] = None - reranking_mode: Optional[str] = None + reranking_model: RerankingModel | None = None + reranking_mode: str | None = None top_k: int score_threshold_enabled: bool - score_threshold: Optional[float] = None - weights: Optional[WeightModel] = None + score_threshold: float | None = None + weights: WeightModel | None = None class MetaDataConfig(BaseModel): @@ -111,29 +111,29 @@ class MetaDataConfig(BaseModel): class KnowledgeConfig(BaseModel): - original_document_id: Optional[str] = None + original_document_id: str | None = None duplicate: bool = True indexing_technique: Literal["high_quality", "economy"] - data_source: Optional[DataSource] = None - process_rule: Optional[ProcessRule] = None - retrieval_model: Optional[RetrievalModel] = None + data_source: 
DataSource | None = None + process_rule: ProcessRule | None = None + retrieval_model: RetrievalModel | None = None doc_form: str = "text_model" doc_language: str = "English" - embedding_model: Optional[str] = None - embedding_model_provider: Optional[str] = None - name: Optional[str] = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + name: str | None = None class SegmentUpdateArgs(BaseModel): - content: Optional[str] = None - answer: Optional[str] = None - keywords: Optional[list[str]] = None + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None regenerate_child_chunks: bool = False - enabled: Optional[bool] = None + enabled: bool | None = None class ChildChunkUpdateArgs(BaseModel): - id: Optional[str] = None + id: str | None = None content: str @@ -144,13 +144,13 @@ class MetadataArgs(BaseModel): class MetadataUpdateArgs(BaseModel): name: str - value: Optional[str | int | float] = None + value: str | int | float | None = None class MetadataDetail(BaseModel): id: str name: str - value: Optional[str | int | float] = None + value: str | int | float | None = None class DocumentMetadataOperation(BaseModel): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 647052d739..49d48f044c 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -42,11 +41,11 @@ class CustomConfigurationResponse(BaseModel): """ status: CustomConfigurationStatus - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None - available_credentials: Optional[list[CredentialConfiguration]] = None - custom_models: Optional[list[CustomModelConfiguration]] = None - can_added_models: Optional[list[UnaddedModelConfiguration]] = None + current_credential_id: str | None = None + current_credential_name: str | None = None + available_credentials: list[CredentialConfiguration] | None = None + custom_models: list[CustomModelConfiguration] | None = None + can_added_models: list[UnaddedModelConfiguration] | None = None class SystemConfigurationResponse(BaseModel): @@ -55,7 +54,7 @@ class SystemConfigurationResponse(BaseModel): """ enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None + current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] @@ -67,15 +66,15 @@ class ProviderResponse(BaseModel): tenant_id: str provider: str label: I18nObject - description: Optional[I18nObject] = None - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None - background: Optional[str] = None - help: Optional[ProviderHelpEntity] = None + description: I18nObject | None = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None + background: str | None = None + help: ProviderHelpEntity | None = None supported_model_types: list[ModelType] configurate_methods: list[ConfigurateMethod] - provider_credential_schema: Optional[ProviderCredentialSchema] = None - model_credential_schema: Optional[ModelCredentialSchema] = None + provider_credential_schema: ProviderCredentialSchema | None = None + model_credential_schema: ModelCredentialSchema | None = None preferred_provider_type: ProviderType custom_configuration: CustomConfigurationResponse system_configuration: SystemConfigurationResponse @@ -108,8 
+107,8 @@ class ProviderWithModelsResponse(BaseModel): tenant_id: str provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] diff --git a/api/services/errors/base.py b/api/services/errors/base.py index 35ea28468e..0f9631190f 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,6 +1,3 @@ -from typing import Optional - - class BaseServiceError(ValueError): - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py index ca4c9a611d..5bf34f3aa6 100644 --- a/api/services/errors/llm.py +++ b/api/services/errors/llm.py @@ -1,12 +1,9 @@ -from typing import Optional - - class InvokeError(Exception): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description def __str__(self): diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 3262a00663..b6ba3bafea 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from urllib.parse import urlparse import httpx @@ -100,7 +100,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + external_knowledge_api: ExternalKnowledgeApis | None = ( db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first() ) if external_knowledge_api is None: @@ -109,7 +109,7 @@ class ExternalDatasetService: @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + external_knowledge_api: ExternalKnowledgeApis | None = ( db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() ) if external_knowledge_api is None: @@ -151,7 +151,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ( + external_knowledge_binding: ExternalKnowledgeBindings | None = ( db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() ) if not external_knowledge_binding: @@ -181,7 +181,7 @@ class ExternalDatasetService: do http request depending on api bundle """ - kwargs = { + kwargs: dict[str, Any] = { "url": settings.url, "headers": settings.headers, "follow_redirects": True, @@ -203,7 +203,7 @@ class ExternalDatasetService: return response @staticmethod - def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]: + def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]: authorization = deepcopy(authorization) if headers: headers = 
@@ -203,7 +203,7 @@ class ExternalDatasetService:
         return response

     @staticmethod
-    def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
+    def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
         authorization = deepcopy(authorization)
         if headers:
             headers = deepcopy(headers)
@@ -277,7 +277,7 @@ class ExternalDatasetService:
         dataset_id: str,
         query: str,
         external_retrieval_parameters: dict,
-        metadata_condition: Optional[MetadataCondition] = None,
+        metadata_condition: MetadataCondition | None = None,
     ):
         external_knowledge_binding = (
             db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
diff --git a/api/services/file_service.py b/api/services/file_service.py
index f9d4eb5686..5708efba3c 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -1,7 +1,7 @@
 import hashlib
 import os
 import uuid
-from typing import Any, Literal, Union
+from typing import Literal, Union

 from sqlalchemy import Engine
 from sqlalchemy.orm import sessionmaker
@@ -46,7 +46,7 @@ class FileService:
         filename: str,
         content: bytes,
         mimetype: str,
-        user: Union[Account, EndUser, Any],
+        user: Union[Account, EndUser],
         source: Literal["datasets"] | None = None,
         source_url: str = "",
     ) -> UploadFile:
diff --git a/api/services/message_service.py b/api/services/message_service.py
index 13c8e948ca..e2e27443ba 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Union
+from typing import Union

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -29,9 +29,9 @@ class MessageService:
     def pagination_by_first_id(
         cls,
         app_model: App,
-        user: Optional[Union[Account, EndUser]],
+        user: Union[Account, EndUser] | None,
         conversation_id: str,
-        first_id: Optional[str],
+        first_id: str | None,
         limit: int,
         order: str = "asc",
     ) -> InfiniteScrollPagination:
@@ -91,11 +91,11 @@ class MessageService:
     def pagination_by_last_id(
         cls,
         app_model: App,
-        user: Optional[Union[Account, EndUser]],
-        last_id: Optional[str],
+        user: Union[Account, EndUser] | None,
+        last_id: str | None,
         limit: int,
-        conversation_id: Optional[str] = None,
-        include_ids: Optional[list] = None,
+        conversation_id: str | None = None,
+        include_ids: list | None = None,
     ) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@@ -145,9 +145,9 @@ class MessageService:
         *,
         app_model: App,
         message_id: str,
-        user: Optional[Union[Account, EndUser]],
-        rating: Optional[str],
-        content: Optional[str],
+        user: Union[Account, EndUser] | None,
+        rating: str | None,
+        content: str | None,
     ):
         if not user:
             raise ValueError("user cannot be None")
@@ -196,7 +196,7 @@ class MessageService:
         return [record.to_dict() for record in feedbacks]

     @classmethod
-    def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+    def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
         message = (
             db.session.query(Message)
             .where(
@@ -216,7 +216,7 @@ class MessageService:

     @classmethod
     def get_suggested_questions_after_answer(
-        cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
+        cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
     ) -> list[Message]:
         if not user:
             raise ValueError("user cannot be None")
@@ -229,7 +229,7 @@ class MessageService:

         model_manager = ModelManager()

-        if app_model.mode == AppMode.ADVANCED_CHAT.value:
+        if app_model.mode == AppMode.ADVANCED_CHAT:
             workflow_service = WorkflowService()
             if invoke_from == InvokeFrom.DEBUGGER:
                 workflow = workflow_service.get_draft_workflow(app_model=app_model)
diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py
index 05fa5a95bc..6add830813 100644
--- a/api/services/metadata_service.py
+++ b/api/services/metadata_service.py
@@ -1,6 +1,5 @@
 import copy
 import logging
-from typing import Optional

 from flask_login import current_user

@@ -131,11 +130,11 @@ class MetadataService:
     @staticmethod
     def get_built_in_fields():
         return [
-            {"name": BuiltInField.document_name.value, "type": "string"},
-            {"name": BuiltInField.uploader.value, "type": "string"},
-            {"name": BuiltInField.upload_date.value, "type": "time"},
-            {"name": BuiltInField.last_update_date.value, "type": "time"},
-            {"name": BuiltInField.source.value, "type": "string"},
+            {"name": BuiltInField.document_name, "type": "string"},
+            {"name": BuiltInField.uploader, "type": "string"},
+            {"name": BuiltInField.upload_date, "type": "time"},
+            {"name": BuiltInField.last_update_date, "type": "time"},
+            {"name": BuiltInField.source, "type": "string"},
         ]

     @staticmethod
@@ -153,11 +152,11 @@ class MetadataService:
                 doc_metadata = {}
             else:
                 doc_metadata = copy.deepcopy(document.doc_metadata)
-            doc_metadata[BuiltInField.document_name.value] = document.name
-            doc_metadata[BuiltInField.uploader.value] = document.uploader
-            doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
-            doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
-            doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
+            doc_metadata[BuiltInField.document_name] = document.name
+            doc_metadata[BuiltInField.uploader] = document.uploader
+            doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
+            doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
+            doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
             document.doc_metadata = doc_metadata
             db.session.add(document)
         dataset.built_in_field_enabled = True
@@ -183,11 +182,11 @@ class MetadataService:
                 doc_metadata = {}
             else:
                 doc_metadata = copy.deepcopy(document.doc_metadata)
-            doc_metadata.pop(BuiltInField.document_name.value, None)
-            doc_metadata.pop(BuiltInField.uploader.value, None)
-            doc_metadata.pop(BuiltInField.upload_date.value, None)
-            doc_metadata.pop(BuiltInField.last_update_date.value, None)
-            doc_metadata.pop(BuiltInField.source.value, None)
+            doc_metadata.pop(BuiltInField.document_name, None)
+            doc_metadata.pop(BuiltInField.uploader, None)
+            doc_metadata.pop(BuiltInField.upload_date, None)
+            doc_metadata.pop(BuiltInField.last_update_date, None)
+            doc_metadata.pop(BuiltInField.source, None)
             document.doc_metadata = doc_metadata
             db.session.add(document)
             document_ids.append(document.id)
@@ -211,11 +210,11 @@ class MetadataService:
                 for metadata_value in operation.metadata_list:
                     doc_metadata[metadata_value.name] = metadata_value.value
                 if dataset.built_in_field_enabled:
-                    doc_metadata[BuiltInField.document_name.value] = document.name
-                    doc_metadata[BuiltInField.uploader.value] = document.uploader
-                    doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
-                    doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
-                    doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
+                    doc_metadata[BuiltInField.document_name] = document.name
+                    doc_metadata[BuiltInField.uploader] = document.uploader
+                    doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
+                    doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
+                    doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
                 document.doc_metadata = doc_metadata
                 db.session.add(document)
             db.session.commit()
@@ -237,7 +236,7 @@ class MetadataService:
             redis_client.delete(lock_key)

     @staticmethod
-    def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
+    def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
         if dataset_id:
             lock_key = f"dataset_metadata_lock_{dataset_id}"
             if redis_client.get(lock_key):
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
index c638087f63..69da3bfb79 100644
--- a/api/services/model_load_balancing_service.py
+++ b/api/services/model_load_balancing_service.py
@@ -1,9 +1,9 @@
 import json
 import logging
 from json import JSONDecodeError
-from typing import Optional, Union
+from typing import Union

-from sqlalchemy import or_
+from sqlalchemy import or_, select

 from constants import HIDDEN_VALUE
 from core.entities.provider_configuration import ProviderConfiguration
@@ -165,7 +165,7 @@ class ModelLoadBalancingService:

         try:
             if load_balancing_config.encrypted_config:
-                credentials = json.loads(load_balancing_config.encrypted_config)
+                credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
             else:
                 credentials = {}
         except JSONDecodeError:
@@ -180,11 +180,13 @@ class ModelLoadBalancingService:
         for variable in credential_secret_variables:
             if variable in credentials:
                 try:
-                    credentials[variable] = encrypter.decrypt_token_with_decoding(
-                        credentials.get(variable),  # ty: ignore [invalid-argument-type]
-                        decoding_rsa_key,
-                        decoding_cipher_rsa,
-                    )
+                    token_value = credentials.get(variable)
+                    if isinstance(token_value, str):
+                        credentials[variable] = encrypter.decrypt_token_with_decoding(
+                            token_value,
+                            decoding_rsa_key,
+                            decoding_cipher_rsa,
+                        )
                 except ValueError:
                     pass
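Illustrative only, not part of the patch: the isinstance guard above replaces a type-ignore comment. A standalone sketch of why the narrowing is needed, with a hypothetical decrypt_token standing in for encrypter.decrypt_token_with_decoding:

import json

def decrypt_token(token: str) -> str:
    # hypothetical stand-in; the real helper only accepts str
    return token[::-1]

# json.loads yields dict[str, Any]: stored values may legally be bools, numbers, etc.
credentials: dict[str, object] = json.loads('{"api_key": "c2VjcmV0", "verified": true}')

for variable in ("api_key", "verified"):
    token_value = credentials.get(variable)
    if isinstance(token_value, str):  # narrow to str before decrypting, skip the rest
        credentials[variable] = decrypt_token(token_value)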
@@ -209,7 +211,7 @@ class ModelLoadBalancingService:
     def get_load_balancing_config(
         self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
-    ) -> Optional[dict]:
+    ) -> dict | None:
         """
         Get load balancing configuration.
         :param tenant_id: workspace id
@@ -320,16 +322,14 @@ class ModelLoadBalancingService:
         if not isinstance(configs, list):
             raise ValueError("Invalid load balancing configs")

-        current_load_balancing_configs = (
-            db.session.query(LoadBalancingModelConfig)
-            .where(
+        current_load_balancing_configs = db.session.scalars(
+            select(LoadBalancingModelConfig).where(
                 LoadBalancingModelConfig.tenant_id == tenant_id,
                 LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                 LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                 LoadBalancingModelConfig.model_name == model,
             )
-            .all()
-        )
+        ).all()

         # id as key, config as value
         current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
@@ -345,8 +345,9 @@ class ModelLoadBalancingService:
             credential_id = config.get("credential_id")
             enabled = config.get("enabled")

+            credential_record: ProviderCredential | ProviderModelCredential | None = None
+
             if credential_id:
-                credential_record: ProviderCredential | ProviderModelCredential | None = None
                 if config_from == "predefined-model":
                     credential_record = (
                         db.session.query(ProviderCredential)
@@ -477,7 +478,7 @@ class ModelLoadBalancingService:
         model: str,
         model_type: str,
         credentials: dict,
-        config_id: Optional[str] = None,
+        config_id: str | None = None,
     ):
         """
         Validate load balancing credentials.
@@ -535,7 +536,7 @@ class ModelLoadBalancingService:
         model_type: ModelType,
         model: str,
         credentials: dict,
-        load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
+        load_balancing_model_config: LoadBalancingModelConfig | None = None,
         validate: bool = True,
     ):
         """
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index 510b1f1fe6..2901a0d273 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
 from core.model_runtime.entities.model_entities import ModelType, ParameterRule
@@ -52,7 +51,7 @@ class ModelProviderService:

         return provider_configuration

-    def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
+    def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
         """
         get provider list.

@@ -128,9 +127,7 @@ class ModelProviderService:
             for model in provider_configurations.get_models(provider=provider)
         ]

-    def get_provider_credential(
-        self, tenant_id: str, provider: str, credential_id: Optional[str] = None
-    ) -> Optional[dict]:
+    def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
         """
         get provider credentials.

@@ -216,7 +213,7 @@ class ModelProviderService:
     def get_model_credential(
         self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
-    ) -> Optional[dict]:
+    ) -> dict | None:
         """
         Retrieve model-specific credentials.

@@ -449,7 +446,7 @@ class ModelProviderService:

         return model_schema.parameter_rules if model_schema else []

-    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
+    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
         """
         get default model of model type.

@@ -498,7 +495,7 @@ class ModelProviderService:
     def get_model_provider_icon(
         self, tenant_id: str, provider: str, icon_type: str, lang: str
-    ) -> tuple[Optional[bytes], Optional[str]]:
+    ) -> tuple[bytes | None, str | None]:
         """
         get model provider icon.

diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index 2596e9f711..c214640653 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any

 from core.ops.entities.config_entity import BaseTracingConfig
 from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
@@ -15,7 +15,7 @@ class OpsService:
         :param tracing_provider: tracing provider
         :return:
         """
-        trace_config_data: Optional[TraceAppConfig] = (
+        trace_config_data: TraceAppConfig | None = (
             db.session.query(TraceAppConfig)
             .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
             .first()
@@ -153,7 +153,7 @@ class OpsService:
         project_url = None

         # check if trace config already exists
-        trace_config_data: Optional[TraceAppConfig] = (
+        trace_config_data: TraceAppConfig | None = (
             db.session.query(TraceAppConfig)
             .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
             .first()
diff --git a/api/services/plugin/github_service.py b/api/services/plugin/github_service.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py
index b566035258..5db19711e6 100644
--- a/api/services/plugin/plugin_migration.py
+++ b/api/services/plugin/plugin_migration.py
@@ -5,7 +5,7 @@ import time
 from collections.abc import Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 from uuid import uuid4

 import click
@@ -101,6 +101,7 @@ class PluginMigration:
             datetime.timedelta(hours=1),
         ]

+        tenant_count = 0
         for test_interval in test_intervals:
             tenant_count = (
                 session.query(Tenant.id)
@@ -257,7 +258,7 @@ class PluginMigration:
             return []

         agent_app_model_config_ids = [
-            app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
+            app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
         ]

         rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
@@ -282,7 +283,7 @@ class PluginMigration:
         return result

     @classmethod
-    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
+    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
         """
         Fetch plugin unique identifier using plugin id.
         """
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index f405fbfe4c..604adeb7b5 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -1,7 +1,6 @@
 import logging
 from collections.abc import Mapping, Sequence
 from mimetypes import guess_type
-from typing import Optional

 from pydantic import BaseModel

@@ -46,11 +45,11 @@ class PluginService:
     REDIS_TTL = 60 * 5  # 5 minutes

     @staticmethod
-    def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
+    def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
         """
         Fetch the latest plugin version
         """
-        result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
+        result: dict[str, PluginService.LatestPluginCache | None] = {}

         try:
             cache_not_exists = []
@@ -109,7 +108,7 @@ class PluginService:
             raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")

     @staticmethod
-    def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
+    def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
         """
         Check the plugin installation scope
         """
@@ -144,7 +143,7 @@ class PluginService:
         return manager.get_debugging_key(tenant_id)

     @staticmethod
-    def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
+    def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
         """
         List the latest versions of the plugins
         """
diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py
index df9e01e273..64751d186c 100644
--- a/api/services/recommend_app/buildin/buildin_retrieval.py
+++ b/api/services/recommend_app/buildin/buildin_retrieval.py
@@ -1,7 +1,6 @@
 import json
 from os import path
 from pathlib import Path
-from typing import Optional

 from flask import current_app

@@ -14,7 +13,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
     Retrieval recommended app from buildin, the location is constants/recommended_apps.json
     """

-    builtin_data: Optional[dict] = None
+    builtin_data: dict | None = None

     def get_type(self) -> str:
         return RecommendAppType.BUILDIN
@@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
         return builtin_data.get("recommended_apps", {}).get(language, {})

     @classmethod
-    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
+    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
         """
         Fetch recommended app detail from builtin.
         :param app_id: App ID
diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py
index e19f53f120..d0c49325dc 100644
--- a/api/services/recommend_app/database/database_retrieval.py
+++ b/api/services/recommend_app/database/database_retrieval.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from sqlalchemy import select

 from constants.languages import languages
 from extensions.ext_database import db
@@ -31,18 +31,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
         :param language: language
         :return:
         """
-        recommended_apps = (
-            db.session.query(RecommendedApp)
-            .where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
-            .all()
-        )
+        recommended_apps = db.session.scalars(
+            select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
+        ).all()

         if len(recommended_apps) == 0:
-            recommended_apps = (
-                db.session.query(RecommendedApp)
-                .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
-                .all()
-            )
+            recommended_apps = db.session.scalars(
+                select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
+            ).all()

         categories = set()
         recommended_apps_result = []
@@ -74,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
         return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}

     @classmethod
-    def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
+    def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
         """
         Fetch recommended app detail from db.
         :param app_id: App ID
diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py
index 1e59287429..2d57769f63 100644
--- a/api/services/recommend_app/remote/remote_retrieval.py
+++ b/api/services/recommend_app/remote/remote_retrieval.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 import requests

@@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         return RecommendAppType.REMOTE

     @classmethod
-    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
+    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
         """
         Fetch recommended app detail from dify official.
         :param app_id: App ID
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index d9c1b51fa1..544383a106 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from configs import dify_config
 from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory

@@ -25,7 +23,7 @@ class RecommendedAppService:
         return result

     @classmethod
-    def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
+    def get_recommend_app_detail(cls, app_id: str) -> dict | None:
         """
         Get recommend app detail.
         :param app_id: app id
diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py
index 641e03c3cf..67a0106bbd 100644
--- a/api/services/saved_message_service.py
+++ b/api/services/saved_message_service.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Union

 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -11,7 +11,7 @@ from services.message_service import MessageService
 class SavedMessageService:
     @classmethod
     def pagination_by_last_id(
-        cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
+        cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
     ) -> InfiniteScrollPagination:
         if not user:
             raise ValueError("User is required")
@@ -32,7 +32,7 @@ class SavedMessageService:
         )

     @classmethod
-    def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+    def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
         if not user:
             return
         saved_message = (
@@ -62,7 +62,7 @@ class SavedMessageService:
         db.session.commit()

     @classmethod
-    def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+    def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
         if not user:
             return
         saved_message = (
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index a16bdb46cd..4674335ba8 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -1,8 +1,7 @@
 import uuid
-from typing import Optional

 from flask_login import current_user
-from sqlalchemy import func
+from sqlalchemy import func, select
 from werkzeug.exceptions import NotFound

 from extensions.ext_database import db
@@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding

 class TagService:
     @staticmethod
-    def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
+    def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
         query = (
             db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
             .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@@ -29,35 +28,30 @@ class TagService:
         # Check if tag_ids is not empty to avoid WHERE false condition
         if not tag_ids or len(tag_ids) == 0:
             return []
-        tags = (
-            db.session.query(Tag)
-            .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
-            .all()
-        )
+        tags = db.session.scalars(
+            select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+        ).all()
         if not tags:
             return []
         tag_ids = [tag.id for tag in tags]
         # Check if tag_ids is not empty to avoid WHERE false condition
         if not tag_ids or len(tag_ids) == 0:
             return []
-        tag_bindings = (
-            db.session.query(TagBinding.target_id)
-            .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
-            .all()
-        )
-        if not tag_bindings:
-            return []
-        results = [tag_binding.target_id for tag_binding in tag_bindings]
-        return results
+        tag_bindings = db.session.scalars(
+            select(TagBinding.target_id).where(
+                TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
+            )
+        ).all()
+        return tag_bindings

     @staticmethod
     def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
         if not tag_type or not tag_name:
             return []
-        tags = (
-            db.session.query(Tag)
-            .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
-            .all()
+        tags = list(
+            db.session.scalars(
+                select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+            ).all()
         )
         if not tags:
             return []
@@ -117,7 +111,7 @@ class TagService:
             raise NotFound("Tag not found")
         db.session.delete(tag)
         # delete tag binding
-        tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
+        tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
         if tag_bindings:
             for tag_binding in tag_bindings:
                 db.session.delete(tag_binding)
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 78e587abee..f86d7e51bf 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -4,6 +4,7 @@ from collections.abc import Mapping
 from typing import Any, cast

 from httpx import get
+from sqlalchemy import select

 from core.entities.provider_entities import ProviderConfig
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -443,9 +444,7 @@ class ApiToolManageService:
         list api tools
         """
         # get all api providers
-        db_providers: list[ApiToolProvider] = (
-            db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
-        )
+        db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()

         result: list[ToolProviderApiEntity] = []
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index 603165e822..6b0b6b0f0e 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any

 from sqlalchemy import exists, select
 from sqlalchemy.orm import Session
@@ -223,8 +223,8 @@ class BuiltinToolManageService:
         """
         add builtin tool provider
         """
-        try:
-            with Session(db.engine) as session:
+        with Session(db.engine) as session:
+            try:
                 lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
                 with redis_client.lock(lock, timeout=20):
                     provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@@ -285,9 +285,9 @@ class BuiltinToolManageService:

                     session.add(db_provider)
                     session.commit()
-        except Exception as e:
-            session.rollback()
-            raise ValueError(str(e))
+            except Exception as e:
+                session.rollback()
+                raise ValueError(str(e))

         return {"result": "success"}

     @staticmethod
@@ -582,7 +582,7 @@ class BuiltinToolManageService:
         return BuiltinToolProviderSort.sort(result)

     @staticmethod
-    def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
+    def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
         """
         This method is used to fetch the builtin provider from the database
         1.if the default provider exists, return the default provider
@@ -643,8 +643,8 @@ class BuiltinToolManageService:
     def save_custom_oauth_client_params(
         tenant_id: str,
         provider: str,
-        client_params: Optional[dict] = None,
-        enable_oauth_custom_client: Optional[bool] = None,
+        client_params: dict | None = None,
+        enable_oauth_custom_client: bool | None = None,
     ):
         """
         setup oauth custom client
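Aside, not part of the patch: the query rewrite repeated across these service and task files moves from the legacy 1.x Query API to SQLAlchemy 2.0-style select() plus scalars(). A minimal sketch of the two forms, using a hypothetical Tag model standing in for models.model.Tag:

from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Tag(Base):  # hypothetical stand-in for models.model.Tag
    __tablename__ = "tags"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)

def list_tags_legacy(session: Session, tenant_id: str) -> list[Tag]:
    # 1.x-style Query API, as removed by these hunks
    return session.query(Tag).where(Tag.tenant_id == tenant_id).all()

def list_tags(session: Session, tenant_id: str) -> list[Tag]:
    # 2.0-style: scalars() unwraps single-entity rows; .all() returns a Sequence
    return list(session.scalars(select(Tag).where(Tag.tenant_id == tenant_id)).all())

Both forms return the same rows; the 2.0 form is the one SQLAlchemy recommends going forward and plays better with typed results.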
diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py
index 7e301c9bac..dd626dd615 100644
--- a/api/services/tools/mcp_tools_manage_service.py
+++ b/api/services/tools/mcp_tools_manage_service.py
@@ -259,11 +259,30 @@ class MCPToolManageService:
             if sse_read_timeout is not None:
                 mcp_provider.sse_read_timeout = sse_read_timeout
             if headers is not None:
-                # Encrypt headers
+                # Merge masked headers from frontend with existing real values
                 if headers:
-                    encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
+                    # existing decrypted and masked headers
+                    existing_decrypted = mcp_provider.decrypted_headers
+                    existing_masked = mcp_provider.masked_headers
+
+                    # Build final headers: if value equals masked existing, keep original decrypted value
+                    final_headers: dict[str, str] = {}
+                    for key, incoming_value in headers.items():
+                        if (
+                            key in existing_masked
+                            and key in existing_decrypted
+                            and isinstance(incoming_value, str)
+                            and incoming_value == existing_masked.get(key)
+                        ):
+                            # unchanged, use original decrypted value
+                            final_headers[key] = str(existing_decrypted[key])
+                        else:
+                            final_headers[key] = incoming_value
+
+                    encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
                     mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
                 else:
+                    # Explicitly clear headers if empty dict passed
                     mcp_provider.encrypted_headers = None
             db.session.commit()
         except IntegrityError as e:
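A standalone sketch of the masked-credential merge above, with plain dicts in place of the provider model: values that round-trip unchanged as their mask are treated as "not edited" and keep the stored plaintext, while anything else is taken as a newly typed value.

def merge_masked_headers(
    incoming: dict[str, str],
    existing_decrypted: dict[str, str],
    existing_masked: dict[str, str],
) -> dict[str, str]:
    final_headers: dict[str, str] = {}
    for key, incoming_value in incoming.items():
        if key in existing_masked and key in existing_decrypted and incoming_value == existing_masked.get(key):
            final_headers[key] = existing_decrypted[key]  # unchanged in the UI, keep the real value
        else:
            final_headers[key] = incoming_value  # user supplied a new value
    return final_headers

assert merge_masked_headers(
    {"Authorization": "Bear******oken"},
    {"Authorization": "Bearer real-token"},
    {"Authorization": "Bear******oken"},
) == {"Authorization": "Bearer real-token"}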
diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py
index f245dd7527..51e9120b8d 100644
--- a/api/services/tools/tools_manage_service.py
+++ b/api/services/tools/tools_manage_service.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
 from core.tools.tool_manager import ToolManager
@@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)

 class ToolCommonService:
     @staticmethod
-    def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
+    def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
         """
         list tool providers

diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index bea62bbe9a..2325c707ff 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Optional, Union, cast
+from typing import Any, Union, cast

 from yarl import URL

@@ -107,7 +107,7 @@ class ToolTransformService:
     def builtin_provider_to_user_provider(
         cls,
         provider_controller: BuiltinToolProviderController | PluginToolProviderController,
-        db_provider: Optional[BuiltinToolProvider],
+        db_provider: BuiltinToolProvider | None,
         decrypt_credentials: bool = True,
     ) -> ToolProviderApiEntity:
         """
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index 2f8a91ed82..2449536d5c 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -3,7 +3,7 @@ from collections.abc import Mapping
 from datetime import datetime
 from typing import Any

-from sqlalchemy import or_
+from sqlalchemy import or_, select

 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.__base.tool_provider import ToolProviderController
@@ -186,7 +186,9 @@ class WorkflowToolManageService:
         :param tenant_id: the tenant id
         :return: the list of tools
         """
-        db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
+        db_tools = db.session.scalars(
+            select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
+        ).all()

         tools: list[WorkflowToolProviderController] = []
         for provider in db_tools:
diff --git a/api/services/vector_service.py b/api/services/vector_service.py
index 428abdde17..1c559f2c2b 100644
--- a/api/services/vector_service.py
+++ b/api/services/vector_service.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
 class VectorService:
     @classmethod
     def create_segments_vector(
-        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
+        cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
     ):
         documents: list[Document] = []

@@ -79,7 +78,7 @@ class VectorService:
         index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

     @classmethod
-    def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
+    def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
         # update segment index task

         # format new index
diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py
index c48e24f244..0f54e838f3 100644
--- a/api/services/web_conversation_service.py
+++ b/api/services/web_conversation_service.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Union

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -19,11 +19,11 @@ class WebConversationService:
         *,
         session: Session,
         app_model: App,
-        user: Optional[Union[Account, EndUser]],
-        last_id: Optional[str],
+        user: Union[Account, EndUser] | None,
+        last_id: str | None,
         limit: int,
         invoke_from: InvokeFrom,
-        pinned: Optional[bool] = None,
+        pinned: bool | None = None,
         sort_by="-updated_at",
     ) -> InfiniteScrollPagination:
         if not user:
@@ -60,7 +60,7 @@ class WebConversationService:
         )

     @classmethod
-    def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+    def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
         if not user:
             return
         pinned_conversation = (
@@ -92,7 +92,7 @@ class WebConversationService:
         db.session.commit()

     @classmethod
-    def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+    def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
         if not user:
             return
         pinned_conversation = (
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index bb46bf3090..066dc9d741 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -1,7 +1,7 @@
 import enum
 import secrets
 from datetime import UTC, datetime, timedelta
-from typing import Any, Optional
+from typing import Any

 from werkzeug.exceptions import NotFound, Unauthorized

@@ -63,7 +63,7 @@ class WebAppAuthService:

     @classmethod
     def send_email_code_login_email(
-        cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
+        cls, account: Account | None = None, email: str | None = None, language: str = "en-US"
     ):
         email = account.email if account else email
         if email is None:
@@ -82,7 +82,7 @@ class WebAppAuthService:
         return token

     @classmethod
-    def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
+    def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
         return TokenManager.get_token_data(token, "email_code_login")

     @classmethod
@@ -130,7 +130,7 @@ class WebAppAuthService:

     @classmethod
     def is_app_require_permission_check(
-        cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
+        cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None
     ) -> bool:
         """
         Check if the app requires permission check based on its access mode.
diff --git a/api/services/website_service.py b/api/services/website_service.py
index a905001e22..35a6cc52d6 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,7 +1,7 @@
 import datetime
 import json
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any

 import requests
 from flask_login import current_user
@@ -21,9 +21,9 @@ class CrawlOptions:
     limit: int = 1
     crawl_sub_pages: bool = False
     only_main_content: bool = False
-    includes: Optional[str] = None
-    excludes: Optional[str] = None
-    max_depth: Optional[int] = None
+    includes: str | None = None
+    excludes: str | None = None
+    max_depth: int | None = None
     use_sitemap: bool = True

     def get_include_paths(self) -> list[str]:
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 2994856b54..9ce5b6dbe0 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Optional
+from typing import Any

 from core.app.app_config.entities import (
     DatasetEntity,
@@ -18,6 +18,7 @@ from core.helper import encrypter
 from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.simple_prompt_transform import SimplePromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.nodes import NodeType
 from events.app_event import app_was_created
 from extensions.ext_database import db
@@ -64,7 +65,7 @@ class WorkflowConverter:
         new_app = App()
         new_app.tenant_id = app_model.tenant_id
         new_app.name = name or app_model.name + "(workflow)"
-        new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
+        new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
         new_app.icon_type = icon_type or app_model.icon_type
         new_app.icon = icon or app_model.icon
         new_app.icon_background = icon_background or app_model.icon_background
@@ -202,7 +203,7 @@ class WorkflowConverter:
         app_mode_enum = AppMode.value_of(app_model.mode)
         app_config: EasyUIBasedAppConfig
         if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
-            app_model.mode = AppMode.AGENT_CHAT.value
+            app_model.mode = AppMode.AGENT_CHAT
             app_config = AgentChatAppConfigManager.get_app_config(
                 app_model=app_model, app_model_config=app_model_config
             )
@@ -278,7 +279,7 @@ class WorkflowConverter:
                     "app_id": app_model.id,
                     "tool_variable": tool_variable,
                     "inputs": inputs,
-                    "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
+                    "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
                 },
             }

@@ -326,7 +327,7 @@
     def _convert_to_knowledge_retrieval_node(
         self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
-    ) -> Optional[dict]:
+    ) -> dict | None:
         """
         Convert datasets to Knowledge Retrieval Node
         :param new_app_mode: new app mode
@@ -382,7 +383,7 @@ class WorkflowConverter:
         graph: dict,
         model_config: ModelConfigEntity,
         prompt_template: PromptTemplateEntity,
-        file_upload: Optional[FileUploadConfig] = None,
+        file_upload: FileUploadConfig | None = None,
         external_data_variable_node_mapping: dict[str, str] | None = None,
     ):
         """
@@ -402,7 +403,7 @@ class WorkflowConverter:
         )

         role_prefix = None
-        prompts: Optional[Any] = None
+        prompts: Any | None = None

         # Chat Model
         if model_config.mode == LLMMode.CHAT.value:
@@ -420,7 +421,11 @@ class WorkflowConverter:
                 query_in_prompt=False,
             )

-            template = prompt_template_config["prompt_template"].template
+            prompt_template_obj = prompt_template_config["prompt_template"]
+            if not isinstance(prompt_template_obj, PromptTemplateParser):
+                raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
+
+            template = prompt_template_obj.template
             if not template:
                 prompts = []
             else:
@@ -457,7 +462,11 @@ class WorkflowConverter:
                 query_in_prompt=False,
             )

-            template = prompt_template_config["prompt_template"].template
+            prompt_template_obj = prompt_template_config["prompt_template"]
+            if not isinstance(prompt_template_obj, PromptTemplateParser):
+                raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
+
+            template = prompt_template_obj.template
             template = self._replace_template_variables(
                 template=template,
                 variables=start_node["data"]["variables"],
@@ -467,6 +476,9 @@ class WorkflowConverter:
             prompts = {"text": template}

             prompt_rules = prompt_template_config["prompt_rules"]
+            if not isinstance(prompt_rules, dict):
+                raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
+
             role_prefix = {
                 "user": prompt_rules.get("human_prefix", "Human"),
                 "assistant": prompt_rules.get("assistant_prefix", "Assistant"),
@@ -606,7 +618,7 @@ class WorkflowConverter:
         :param app_model: App instance
         :return: AppMode
         """
-        if app_model.mode == AppMode.COMPLETION.value:
+        if app_model.mode == AppMode.COMPLETION:
             return AppMode.WORKFLOW
         else:
             return AppMode.ADVANCED_CHAT
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index e43999a8c9..79d91cab4c 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -1,6 +1,5 @@
 import threading
 from collections.abc import Sequence
-from typing import Optional

 from sqlalchemy.orm import sessionmaker

@@ -80,7 +79,7 @@ class WorkflowRunService:
             last_id=last_id,
         )

-    def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
+    def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
         """
         Get workflow run detail

diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index dbd83324d7..906b5e3bab 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -2,7 +2,7 @@ import json
 import time
 import uuid
 from collections.abc import Callable, Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, cast

 from sqlalchemy import exists, select
 from sqlalchemy.orm import Session, sessionmaker
@@ -85,7 +85,7 @@ class WorkflowService:
         )
         return db.session.execute(stmt).scalar_one()

-    def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]:
+    def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
         """
         Get draft workflow
         """
@@ -105,7 +105,7 @@ class WorkflowService:
         # return draft workflow
         return workflow

-    def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
+    def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
         """
         fetch published workflow by workflow_id
         """
@@ -127,7 +127,7 @@ class WorkflowService:
         )
         return workflow

-    def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
+    def get_published_workflow(self, app_model: App) -> Workflow | None:
         """
         Get published workflow
         """
@@ -192,7 +192,7 @@ class WorkflowService:
         app_model: App,
         graph: dict,
         features: dict,
-        unique_hash: Optional[str],
+        unique_hash: str | None,
         account: Account,
         environment_variables: Sequence[Variable],
         conversation_variables: Sequence[Variable],
@@ -559,7 +559,7 @@ class WorkflowService:

         return default_block_configs

-    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
+    def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
         """
         Get default config of node.
         :param node_type: node type
@@ -856,7 +856,7 @@ class WorkflowService:
         # chatbot convert to workflow mode
         workflow_converter = WorkflowConverter()

-        if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
+        if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
             raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")

         # convert to workflow
@@ -872,11 +872,11 @@ class WorkflowService:
         return new_app

     def validate_features_structure(self, app_model: App, features: dict):
-        if app_model.mode == AppMode.ADVANCED_CHAT.value:
+        if app_model.mode == AppMode.ADVANCED_CHAT:
             return AdvancedChatAppConfigManager.config_validate(
                 tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
             )
-        elif app_model.mode == AppMode.WORKFLOW.value:
+        elif app_model.mode == AppMode.WORKFLOW:
             return WorkflowAppConfigManager.config_validate(
                 tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
             )
@@ -885,7 +885,7 @@ class WorkflowService:
     def update_workflow(
         self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
-    ) -> Optional[Workflow]:
+    ) -> Workflow | None:
         """
         Update workflow attributes

diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index d4fc68a084..292ac6e008 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -12,7 +12,7 @@ class WorkspaceService:
     def get_tenant_info(cls, tenant: Tenant):
         if not tenant:
             return None
-        tenant_info = {
+        tenant_info: dict[str, object] = {
             "id": tenant.id,
             "name": tenant.name,
             "plan": tenant.plan,
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index 3498e08426..cdc07c77a8 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
@@ -39,7 +40,7 @@ def enable_annotation_reply_task(
             db.session.close()
             return

-    annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
+    annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
     enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
     enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py
index 7a72c27b0c..9b3857b4a5 100644
--- a/api/tasks/batch_clean_document_task.py
+++ b/api/tasks/batch_clean_document_task.py
@@ -4,6 +4,7 @@ from typing import Optional

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -37,7 +38,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
         if not dataset:
             raise Exception("Document has no dataset")

-        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
+        segments = db.session.scalars(
+            select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+        ).all()
         # check segment is exist
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
@@ -62,7 +65,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
             db.session.commit()

         if file_ids:
-            files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+            files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
             for file in files:
                 try:
                     storage.delete(file.key)
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 9d12b6a589..5f2a355d16 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -55,8 +56,8 @@ def clean_dataset_task(
             index_struct=index_struct,
             collection_binding_id=collection_binding_id,
         )
-        documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
-        segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
+        documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
+        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()

         # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
         # This ensures all invalid doc_form values are properly handled
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 6549ad04b5..62200715cc 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -1,9 +1,9 @@
 import logging
 import time
-from typing import Optional

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)

 @shared_task(queue="dataset")
-def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]):
+def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None):
     """
     Clean document when document deleted.
     :param document_id: document id
@@ -35,7 +35,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
         if not dataset:
             raise Exception("Document has no dataset")

-        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
         # check segment is exist
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index e7a61e22f2..771b43f9b0 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
@@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
             document = db.session.query(Document).where(Document.id == document_id).first()
             db.session.delete(document)

-            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+            segments = db.session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            ).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 986e9dbc3c..6b2907cffd 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -1,6 +1,5 @@
 import logging
 import time
-from typing import Optional

 import click
 from celery import shared_task
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)

 @shared_task(queue="dataset")
-def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None):
+def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None):
     """
     Async create segment to index
     :param segment_id:
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 23e929c57e..dc6ef6fb61 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -4,6 +4,7 @@ from typing import Literal

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
         if action == "remove":
             index_processor.clean(dataset, None, with_keywords=False)
         elif action == "add":
-            dataset_documents = (
-                db.session.query(DatasetDocument)
-                .where(
+            dataset_documents = db.session.scalars(
+                select(DatasetDocument).where(
                     DatasetDocument.dataset_id == dataset_id,
                     DatasetDocument.indexing_status == "completed",
                     DatasetDocument.enabled == True,
                     DatasetDocument.archived == False,
                 )
-                .all()
-            )
+            ).all()

             if dataset_documents:
                 dataset_documents_ids = [doc.id for doc in dataset_documents]
@@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
             )
             db.session.commit()
         elif action == "update":
-            dataset_documents = (
-                db.session.query(DatasetDocument)
-                .where(
+            dataset_documents = db.session.scalars(
+                select(DatasetDocument).where(
                     DatasetDocument.dataset_id == dataset_id,
                     DatasetDocument.indexing_status == "completed",
                     DatasetDocument.enabled == True,
                     DatasetDocument.archived == False,
                 )
-                .all()
-            )
+            ).all()

             # add new index
             if dataset_documents:
                 # update document status
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index 0b750cf4db..e8cbd0f250 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -12,7 +12,9 @@ logger = logging.getLogger(__name__)

 @shared_task(queue="dataset")
-def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
+def delete_segment_from_index_task(
+    index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
+):
     """
     Async Remove segment from index
     :param index_node_ids:
@@ -26,6 +28,7 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
     try:
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
+            logger.warning("Dataset %s not found, skipping index cleanup", dataset_id)
             return

         dataset_document = db.session.query(Document).where(Document.id == document_id).first()
@@ -33,11 +36,19 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
             return

         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+            logger.info("Document not in valid state for index operations, skipping")
             return

+        doc_form = dataset_document.doc_form
-        index_type = dataset_document.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+        # Proceed with index cleanup using the index_node_ids directly
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        index_processor.clean(
+            dataset,
+            index_node_ids,
+            with_keywords=True,
+            delete_child_chunks=True,
+            precomputed_child_node_ids=child_node_ids,
+        )

         end_at = time.perf_counter()
         logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py
index d4899fe0e4..9038dc179b 100644
--- a/api/tasks/disable_segments_from_index_task.py
+++ b/api/tasks/disable_segments_from_index_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
@@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
     # sync index processor
     index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.id.in_(segment_ids),
             DocumentSegment.dataset_id == dataset_id,
             DocumentSegment.document_id == document_id,
         )
-        .all()
-    )
+    ).all()

     if not segments:
         db.session.close()
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index a0950b4719..226d990edb 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.extractor.notion_extractor import NotionExtractor
@@ -70,7 +71,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             index_type = document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()

-            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+            segments = db.session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            ).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             # delete from vector index
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 48566b6104..161502a228 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
     index_type = document.doc_form
     index_processor = IndexProcessorFactory(index_type).init_index_processor()

-    segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+    segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
     if segments:
         index_node_ids = [segment.index_node_id for segment in segments]
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index d93f30ba37..2020179cd9 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from configs import dify_config
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
             index_type = document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()

-            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+            segments = db.session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            ).all()
             if segments:
                 index_node_ids = [segment.index_node_id for segment in segments]
diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py
index 647664641d..c5ca7a6171 100644
--- a/api/tasks/enable_segments_to_index_task.py
+++ b/api/tasks/enable_segments_to_index_task.py
@@ -3,6 +3,7 @@ import time

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
     # sync index processor
     index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.id.in_(segment_ids),
             DocumentSegment.dataset_id == dataset_id,
             DocumentSegment.document_id == document_id,
         )
-        .all()
-    )
+    ).all()
     if not segments:
         logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
         db.session.close()
diff --git a/api/tasks/mail_register_task.py b/api/tasks/mail_register_task.py
new file mode 100644
index 0000000000..a9472a6119
--- /dev/null
+++ b/api/tasks/mail_register_task.py
@@ -0,0 +1,87 @@
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from configs import dify_config
+from extensions.ext_mail import mail
+from libs.email_i18n import EmailType, get_email_i18n_service
+
+logger = logging.getLogger(__name__)
+
+
+@shared_task(queue="mail")
+def send_email_register_mail_task(language: str, to: str, code: str) -> None:
+    """
+    Send email register email with internationalization support.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+        code: Email register code
+    """
+    if not mail.is_inited():
+        return
+
+    logger.info(click.style(f"Start email register mail to {to}", fg="green"))
+    start_at = time.perf_counter()
+
+    try:
+        email_service = get_email_i18n_service()
+        email_service.send_email(
+            email_type=EmailType.EMAIL_REGISTER,
+            language_code=language,
+            to=to,
+            template_context={
+                "to": to,
+                "code": code,
+            },
+        )
+
+        end_at = time.perf_counter()
+        logger.info(
+            click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
+        )
+    except Exception:
+        logger.exception("Send email register mail to %s failed", to)
+
+
+@shared_task(queue="mail")
+def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None:
+    """
+    Send email register email with internationalization support when the account already exists.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+        account_name: Name of the existing account
+    """
+    if not mail.is_inited():
+        return
+
+    logger.info(click.style(f"Start email register mail to {to}", fg="green"))
+    start_at = time.perf_counter()
+
+    try:
+        login_url = f"{dify_config.CONSOLE_WEB_URL}/signin"
+        reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password"
+
+        email_service = get_email_i18n_service()
+        email_service.send_email(
+            email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST,
+            language_code=language,
+            to=to,
+            template_context={
+                "to": to,
+                "login_url": login_url,
+                "reset_password_url": reset_password_url,
+                "account_name": account_name,
+            },
+        )
+
+        end_at = time.perf_counter()
+        logger.info(
+            click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
+        )
+    except Exception:
+        logger.exception("Send email register mail to %s failed", to)
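Illustrative only, the call sites are not part of this diff: a typical way a caller would enqueue the new task. .delay() routes the job to the "mail" queue declared in the shared_task decorator; the module path assumes the api project's tasks package layout.

from tasks.mail_register_task import send_email_register_mail_task

send_email_register_mail_task.delay(language="en-US", to="user@example.com", code="123456")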
diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 545db84fde..1739562588 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -4,6 +4,7 @@ import time import click from celery import shared_task +from configs import dify_config from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -44,3 +45,47 @@ def send_reset_password_mail_task(language: str, to: str, code: str): ) except Exception: logger.exception("Send password reset mail to %s failed", to) + + +@shared_task(queue="mail") +def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None: + """ + Send the reset-password email with internationalization support when the account does not exist. + + Args: + language: Language code for email localization + to: Recipient email address + is_allow_register: Whether sign-up is open; selects the template variant that links to sign-up + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start password reset mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + if is_allow_register: + sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup" + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST, + language_code=language, + to=to, + template_context={ + "to": to, + "sign_up_url": sign_up_url, + }, + ) + else: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER, + language_code=language, + to=to, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send password reset mail to %s failed", to)
diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index d871b297e0..bae8f1c4db 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -1,3 +1,4 @@ +import operator import traceback import typing @@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task( current_version = version latest_version = manifest.latest_version - def fix_only_checker(latest_version, current_version): + def fix_only_checker(latest_version: str, current_version: str): latest_version_tuple = tuple(int(val) for val in latest_version.split(".")) current_version_tuple = tuple(int(val) for val in current_version.split(".")) @@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task( return False version_checker = { - TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version, - current_version: latest_version != current_version, + TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne, TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker, }
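The strategy table now binds LATEST directly to operator.ne, which is exactly the predicate the deleted lambda spelled out. A self-contained sketch of the two checks; the fix-only rule is inferred from the surrounding hunk (upgrade only within the same major.minor), since the function body is only partially shown above:

    import operator

    def fix_only(latest: str, current: str) -> bool:
        # Assumed semantics: accept patch-level upgrades only.
        lt = tuple(int(v) for v in latest.split("."))
        ct = tuple(int(v) for v in current.split("."))
        return lt[:2] == ct[:2] and lt > ct

    checkers = {"latest": operator.ne, "fix_only": fix_only}

    assert checkers["latest"]("1.2.3", "1.2.2")        # any difference counts
    assert checkers["fix_only"]("1.2.3", "1.2.2")      # patch bump: upgrade
    assert not checkers["fix_only"]("1.3.0", "1.2.9")  # minor bump: skip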
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index ec56ab583b..c0ab2d0b41 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try:
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 1899f93ff7..ff7848eea6 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -81,7 +82,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 3c7c69e3c8..0dc1d841f4 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index
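The segment-cleanup tasks above all received the same mechanical change: the legacy Query API is replaced with an explicit select(). Both forms return the same list of ORM entities; a sketch using the names these files already use (the db session and DocumentSegment model from the repo):

    from sqlalchemy import select

    # Legacy 1.x style:
    #     segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
    # 2.0 style: build a Select, execute it, and let scalars() unwrap
    # single-entity rows into model instances:
    stmt = select(DocumentSegment).where(DocumentSegment.document_id == document_id)
    segments = db.session.scalars(stmt).all()
    index_node_ids = [segment.index_node_id for segment in segments]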
[The sixteen new email templates below are HTML files whose markup was lost in extraction; each diff keeps its header and the template's visible text.]
diff --git a/api/templates/register_email_template_en-US.html b/api/templates/register_email_template_en-US.html new file mode 100644 index 0000000000..e0fec59100 --- /dev/null +++ b/api/templates/register_email_template_en-US.html @@ -0,0 +1,87 @@
+ Dify Logo
+ Dify Sign-up Code
+ Your sign-up code for Dify
+ Copy and paste this code, this code will only be valid for the next 5 minutes.
+ {{code}}
+ If you didn't request this code, don't worry. You can safely ignore this email.
\ No newline at end of file
diff --git a/api/templates/register_email_template_zh-CN.html b/api/templates/register_email_template_zh-CN.html new file mode 100644 index 0000000000..3b507290f0 --- /dev/null +++ b/api/templates/register_email_template_zh-CN.html @@ -0,0 +1,87 @@
+ Dify Logo
+ Dify 注册验证码
+ 您的 Dify 注册验证码
+ 复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。
+ {{code}}
+ 如果您没有请求,请不要担心。您可以安全地忽略此电子邮件。
\ No newline at end of file
diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html new file mode 100644 index 0000000000..ac5042c274 --- /dev/null +++ b/api/templates/register_email_when_account_exist_template_en-US.html @@ -0,0 +1,130 @@
+ Dify Logo
+ It looks like you’re signing up with an existing account
+ Hi, {{account_name}}
+ We noticed you tried to sign up, but this email is already registered with an existing account. Please log in here:
+ Log In
+ If you forgot your password, you can reset it here: Reset Password
+ If you didn’t request this action, you can safely ignore this email.
+ Please do not reply directly to this email, it is automatically sent by the system.
diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html new file mode 100644 index 0000000000..326b58343a --- /dev/null +++ b/api/templates/register_email_when_account_exist_template_zh-CN.html @@ -0,0 +1,127 @@
+ Dify Logo
+ 您似乎正在使用现有账户注册
+ 您好,{{account_name}}
+ 我们注意到您尝试注册,但此电子邮件已注册。 请在此登录:
+ 登录
+ 如果您忘记了密码,可以在此重置: 重置密码
+ 如果您没有请求此操作,您可以安全地忽略此电子邮件。
+ 请不要直接回复此电子邮件,它是由系统自动发送的。
diff --git a/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html new file mode 100644 index 0000000000..1c5253a239 --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html @@ -0,0 +1,122 @@
+ Dify Logo
+ It looks like you’re resetting a password with an unregistered email
+ Hi,
+ We noticed you tried to reset your password, but this email is not associated with any account.
+ If you didn’t request this action, you can safely ignore this email.
+ Please do not reply directly to this email, it is automatically sent by the system.
\ No newline at end of file
diff --git a/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html new file mode 100644 index 0000000000..1431291218 --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html @@ -0,0 +1,121 @@
+ Dify Logo
+ 看起来您正在使用未注册的电子邮件重置密码
+ 您好,
+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。
+ 如果您没有请求此操作,您可以安全地忽略此电子邮件。
+ 请不要直接回复此电子邮件,它是由系统自动发送的。
\ No newline at end of file
diff --git a/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html b/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html new file mode 100644 index 0000000000..5759d56f7c --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html @@ -0,0 +1,124 @@
+ Dify Logo
+ It looks like you’re resetting a password with an unregistered email
+ Hi,
+ We noticed you tried to reset your password, but this email is not associated with any account. Please sign up here:
+ Sign Up
+ If you didn’t request this action, you can safely ignore this email.
+ Please do not reply directly to this email, it is automatically sent by the system.
\ No newline at end of file
diff --git a/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html b/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html new file mode 100644 index 0000000000..4de4a8abaa --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html @@ -0,0 +1,126 @@
+ Dify Logo
+ 看起来您正在使用未注册的电子邮件重置密码
+ 您好,
+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 请在此注册:
+ 注册
+ 如果您没有请求此操作,您可以安全地忽略此电子邮件。
+ 请不要直接回复此电子邮件,它是由系统自动发送的。
\ No newline at end of file
diff --git a/api/templates/without-brand/register_email_template_en-US.html b/api/templates/without-brand/register_email_template_en-US.html new file mode 100644 index 0000000000..bd67c8ff4a --- /dev/null +++ b/api/templates/without-brand/register_email_template_en-US.html @@ -0,0 +1,83 @@
+ {{application_title}} Sign-up Code
+ Your sign-up code
+ Copy and paste this code, this code will only be valid for the next 5 minutes.
+ {{code}}
+ If you didn't request this code, don't worry. You can safely ignore this email.
diff --git a/api/templates/without-brand/register_email_template_zh-CN.html b/api/templates/without-brand/register_email_template_zh-CN.html new file mode 100644 index 0000000000..26df4760aa --- /dev/null +++ b/api/templates/without-brand/register_email_template_zh-CN.html @@ -0,0 +1,83 @@
+ {{application_title}} 注册验证码
+ 您的 {{application_title}} 注册验证码
+ 复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。
+ {{code}}
+ 如果您没有请求此验证码,请不要担心。您可以安全地忽略此电子邮件。
\ No newline at end of file
diff --git a/api/templates/without-brand/register_email_when_account_exist_template_en-US.html b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html new file mode 100644 index 0000000000..2e74956e14 --- /dev/null +++ b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html @@ -0,0 +1,126 @@
+ It looks like you’re signing up with an existing account
+ Hi, {{account_name}}
+ We noticed you tried to sign up, but this email is already registered with an existing account. Please log in here:
+ Log In
+ If you forgot your password, you can reset it here: Reset Password
+ If you didn’t request this action, you can safely ignore this email.
+ Please do not reply directly to this email, it is automatically sent by the system.
diff --git a/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html new file mode 100644 index 0000000000..a315f9154d --- /dev/null +++ b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html @@ -0,0 +1,123 @@
+ 您似乎正在使用现有账户注册
+ 您好,{{account_name}}
+ 我们注意到您尝试注册,但此电子邮件已注册。 请在此登录:
+ 登录
+ 如果您忘记了密码,可以在此重置: 重置密码
+ 如果您没有请求此操作,您可以安全地忽略此电子邮件。
+ 请不要直接回复此电子邮件,它是由系统自动发送的。
diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html new file mode 100644 index 0000000000..ae59f36332 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html @@ -0,0 +1,118 @@
+ It looks like you’re resetting a password with an unregistered email
+ Hi,
+ We noticed you tried to reset your password, but this email is not associated with any account.
+ If you didn’t request this action, you can safely ignore this email.
+ Please do not reply directly to this email, it is automatically sent by the system.
diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html new file mode 100644 index 0000000000..4b4fda2c6e --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html @@ -0,0 +1,118 @@
+ 看起来您正在使用未注册的电子邮件重置密码
+ 您好,
+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。
+ 如果您没有请求此操作,您可以安全地忽略此电子邮件。
+ 请不要直接回复此电子邮件,它是由系统自动发送的。
diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html new file mode 100644 index 0000000000..fedc998809 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html @@ -0,0 +1,121 @@
+ It looks like you’re resetting a password with an unregistered email
+ Hi,
+ We noticed you tried to reset your password, but this email is not associated with any account. Please sign up here:
+ Sign Up
+ If you didn’t request this action, you can safely ignore this email.
+ Please do not reply directly to this email, it is automatically sent by the system.
\ No newline at end of file
diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html new file mode 100644 index 0000000000..2464b4a058 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html @@ -0,0 +1,120 @@
+ 看起来您正在使用未注册的电子邮件重置密码
+ 您好,
+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 请在此注册:
+ 注册
+ 如果您没有请求此操作,您可以安全地忽略此电子邮件。
+ 请不要直接回复此电子邮件,它是由系统自动发送的。
\ No newline at end of file
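The {{code}}, {{account_name}} and {{application_title}} placeholders in these templates line up with the template_context dicts passed by the mail tasks earlier in this patch. Assuming Jinja-style rendering (the actual lookup and rendering live behind libs.email_i18n, which this diff doesn't show), the substitution behaves like:

    from jinja2 import Template

    # "245809" is a made-up code used purely for illustration.
    snippet = Template("Your sign-up code for Dify: {{code}}")
    assert snippet.render(code="245809") == "Your sign-up code for Dify: 245809"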
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 2e98dec964..92df93fb13 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -203,6 +203,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py new file mode 100644 index 0000000000..524713fbf1 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -0,0 +1,101 @@ +"""Integration tests for ChatMessageApi permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.app import completion as completion_api +from controllers.console.app import wraps +from libs.datetime_utils import naive_utc_now +from models import Account, App, Tenant +from models.account import TenantAccountRole +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +class TestChatMessageApiPermissions: + """Test permission verification for ChatMessageApi endpoint.""" + + @pytest.fixture + def mock_app_model(self): + """Create a mock App model for testing.""" + app = App() + app.id = str(uuid.uuid4()) + app.mode = AppMode.CHAT.value + app.tenant_id = str(uuid.uuid4()) + app.status = "normal" + return app + + @pytest.fixture + def mock_account(self): + """Create a mock Account for testing.""" + + account = Account() + account.id = str(uuid.uuid4()) + account.name = "Test User" + account.email = "test@example.com" + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant() + tenant.id = str(uuid.uuid4()) + tenant.name = "Test Tenant" + + account._current_tenant = tenant + return account + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_post_permissions_by_role( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_app_model, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test chat-messages endpoint access for each tenant role.""" + + # Mock app loading + mock_load_app_model = mock.Mock(return_value=mock_app_model) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + # Mock current user + monkeypatch.setattr(completion_api, "current_user", mock_account) + + mock_generate = mock.Mock(return_value={"message": "Test response"}) + monkeypatch.setattr(AppGenerateService, "generate", mock_generate) + + # Set user role from the parametrized case + mock_account.role = role + + response = test_client.post( + f"/console/api/apps/{mock_app_model.id}/chat-messages", + headers=auth_header, + json={ + "inputs": {}, + "query": "Hello, world!", + "model_config": { + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}} + }, + "response_mode": "blocking", + }, + ) + +
assert response.status_code == status
diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py new file mode 100644 index 0000000000..ca4d452963 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -0,0 +1,129 @@ +"""Integration tests for ModelConfigResource permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.app import model_config as model_config_api +from controllers.console.app import wraps +from libs.datetime_utils import naive_utc_now +from models import Account, App, Tenant +from models.account import TenantAccountRole +from models.model import AppMode +from services.app_model_config_service import AppModelConfigService + + +class TestModelConfigResourcePermissions: + """Test permission verification for ModelConfigResource endpoint.""" + + @pytest.fixture + def mock_app_model(self): + """Create a mock App model for testing.""" + app = App() + app.id = str(uuid.uuid4()) + app.mode = AppMode.CHAT.value + app.tenant_id = str(uuid.uuid4()) + app.status = "normal" + app.app_model_config_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_account(self): + """Create a mock Account for testing.""" + + account = Account() + account.id = str(uuid.uuid4()) + account.name = "Test User" + account.email = "test@example.com" + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant() + tenant.id = str(uuid.uuid4()) + tenant.name = "Test Tenant" + + account._current_tenant = tenant + return account + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_post_permissions_by_role( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_app_model, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test model-config endpoint access for each tenant role.""" + # Set user role from the parametrized case + mock_account.role = role + + # Mock app loading + mock_load_app_model = mock.Mock(return_value=mock_app_model) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + # Mock current user + monkeypatch.setattr(model_config_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + mock_validate_config = mock.Mock( + return_value={ + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}, + "pre_prompt": "You are a helpful assistant.", + "user_input_form": [], + "dataset_query_variable": "", + "agent_mode": {"enabled": False, "tools": []}, + } + ) + monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config) + + # Mock database operations + mock_db_session = mock.Mock() + mock_db_session.add = mock.Mock() + mock_db_session.flush = mock.Mock() + mock_db_session.commit = mock.Mock() + monkeypatch.setattr(model_config_api.db, "session", mock_db_session) + + # Mock
app_model_config_was_updated event + mock_event = mock.Mock() + mock_event.send = mock.Mock() + monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event) + + response = test_client.post( + f"/console/api/apps/{mock_app_model.id}/model-config", + headers=auth_header, + json={ + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": {"temperature": 0.7, "max_tokens": 1000}, + }, + "user_input_form": [], + "dataset_query_variable": "", + "pre_prompt": "You are a helpful assistant.", + "agent_mode": {"enabled": False, "tools": []}, + }, + ) + + assert response.status_code == status diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 0fb7076c85..bc64fda9c2 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -1,6 +1,5 @@ import unittest from datetime import UTC, datetime -from typing import Optional from unittest.mock import patch from uuid import uuid4 @@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.rollback() def _create_upload_file( - self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None ) -> UploadFile: """Helper method to create an UploadFile record for testing.""" if file_id is None: @@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase): return upload_file def _create_tool_file( - self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None ) -> ToolFile: """Helper method to create a ToolFile record for testing.""" if file_id is None: @@ -101,9 +100,7 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file( - self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None - ) -> File: + def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index d699866fb4..d59d5dc0fe 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -5,8 +5,6 @@ from decimal import Decimal from json import dumps # import monkeypatch -from typing import Optional - from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool @@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient): @staticmethod def generate_function_call( - tools: Optional[list[PromptMessageTool]], - ) -> Optional[AssistantPromptMessage.ToolCall]: + tools: list[PromptMessageTool] | None, + ) -> AssistantPromptMessage.ToolCall | None: if not tools or len(tools) == 0: return None function: PromptMessageTool = tools[0] @@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient): def 
mocked_chat_create_sync( model: str, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> LLMResult: tool_call = MockModelClass.generate_function_call(tools=tools) @@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient): def mocked_chat_create_stream( model: str, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> Generator[LLMResultChunk, None, None]: tool_call = MockModelClass.generate_function_call(tools=tools) @@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, ): return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools) diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py index 293b469ef3..7e60f60adc 100644 --- a/api/tests/integration_tests/storage/test_clickzetta_volume.py +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -3,6 +3,7 @@ import os import tempfile import unittest +from pathlib import Path import pytest @@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase): # Test download with tempfile.NamedTemporaryFile() as temp_file: storage.download(test_filename, temp_file.name) - with open(temp_file.name, "rb") as f: - downloaded_content = f.read() + downloaded_content = Path(temp_file.name).read_bytes() assert downloaded_content == test_content # Test scan diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index be5b4de5a2..f9f9f4f369 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,6 +1,5 @@ import os from collections import UserDict -from typing import Optional from unittest.mock import MagicMock import pytest @@ -22,7 +21,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: Optional[HTTPAdapter] = None, + adapter: HTTPAdapter | None = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index fd7ab0a22b..e0b908cece 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Union +from typing import Union import pytest from _pytest.monkeypatch import MonkeyPatch @@ -23,16 +23,16 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: Optional[HTTPAdapter] = None, + adapter: HTTPAdapter | None = None, pool_size: int = 2, - proxies: Optional[dict] = None, - password: Optional[str] = None, + proxies: dict | None = None, + password: str | None = None, **kwargs, ): self._conn = None self._read_consistency = read_consistency - def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase: + def 
create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase: return RPCDatabase( name="dify", read_consistency=self._read_consistency, @@ -42,7 +42,7 @@ class MockTcvectordbClass: return True def describe_collection( - self, database_name: str, collection_name: str, timeout: Optional[float] = None + self, database_name: str, collection_name: str, timeout: float | None = None ) -> RPCCollection: index = Index( FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY), @@ -71,13 +71,13 @@ class MockTcvectordbClass: collection_name: str, shard: int, replicas: int, - description: Optional[str] = None, - index: Optional[Index] = None, - embedding: Optional[Embedding] = None, - timeout: Optional[float] = None, - ttl_config: Optional[dict] = None, - filter_index_config: Optional[FilterIndexConfig] = None, - indexes: Optional[list[IndexField]] = None, + description: str | None = None, + index: Index | None = None, + embedding: Embedding | None = None, + timeout: float | None = None, + ttl_config: dict | None = None, + filter_index_config: FilterIndexConfig | None = None, + indexes: list[IndexField] | None = None, ) -> RPCCollection: return RPCCollection( RPCDatabase( @@ -102,7 +102,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, documents: list[Union[Document, dict]], - timeout: Optional[float] = None, + timeout: float | None = None, build_index: bool = True, **kwargs, ): @@ -113,12 +113,12 @@ class MockTcvectordbClass: database_name: str, collection_name: str, vectors: list[list[float]], - filter: Optional[Filter] = None, + filter: Filter | None = None, params=None, retrieve_vector: bool = False, limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + output_fields: list[str] | None = None, + timeout: float | None = None, ) -> list[list[dict]]: return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]] @@ -126,14 +126,14 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - ann: Optional[Union[list[AnnSearch], AnnSearch]] = None, - match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None, - filter: Optional[Union[Filter, str]] = None, - rerank: Optional[Rerank] = None, - retrieve_vector: Optional[bool] = None, - output_fields: Optional[list[str]] = None, - limit: Optional[int] = None, - timeout: Optional[float] = None, + ann: Union[list[AnnSearch], AnnSearch] | None = None, + match: Union[list[KeywordSearch], KeywordSearch] | None = None, + filter: Union[Filter, str] | None = None, + rerank: Rerank | None = None, + retrieve_vector: bool | None = None, + output_fields: list[str] | None = None, + limit: int | None = None, + timeout: float | None = None, return_pd_object=False, **kwargs, ) -> list[list[dict]]: @@ -143,13 +143,13 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - document_ids: Optional[list] = None, + document_ids: list | None = None, retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + limit: int | None = None, + offset: int | None = None, + filter: Filter | None = None, + output_fields: list[str] | None = None, + timeout: float | None = None, ): return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] @@ -157,13 +157,13 @@ class MockTcvectordbClass: self, database_name: str, 
collection_name: str, - document_ids: Optional[list[str]] = None, - filter: Optional[Filter] = None, - timeout: Optional[float] = None, + document_ids: list[str] | None = None, + filter: Filter | None = None, + timeout: float | None = None, ): return {"code": 0, "msg": "operation success"} - def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None): + def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None): return {"code": 0, "msg": "operation success"} diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py index 4b251ba836..70c85d4c98 100644 --- a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py @@ -1,6 +1,5 @@ import os from collections import UserDict -from typing import Optional import pytest from _pytest.monkeypatch import MonkeyPatch @@ -34,7 +33,7 @@ class MockIndex: include_vectors: bool = False, include_metadata: bool = False, filter: str = "", - data: Optional[str] = None, + data: str | None = None, namespace: str = "", include_data: bool = False, ): diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index d85d091a2e..76918f689f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,7 +1,6 @@ import os import time import uuid -from typing import Optional from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom @@ -27,7 +26,7 @@ def get_mocked_fetch_memory(memory_text: str): human_prefix: str = "Human", ai_prefix: str = "Assistant", max_token_limit: int = 2000, - message_limit: Optional[int] = None, + message_limit: int | None = None, ): return memory_text diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 86fd6c5a85..145e31bca0 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -11,7 +11,6 @@ import logging import os from collections.abc import Generator from pathlib import Path -from typing import Optional import pytest from flask import Flask @@ -42,10 +41,10 @@ class DifyTestContainers: def __init__(self): """Initialize container management with default configurations.""" - self.postgres: Optional[PostgresContainer] = None - self.redis: Optional[RedisContainer] = None - self.dify_sandbox: Optional[DockerContainer] = None - self.dify_plugin_daemon: Optional[DockerContainer] = None + self.postgres: PostgresContainer | None = None + self.redis: RedisContainer | None = None + self.dify_sandbox: DockerContainer | None = None + self.dify_plugin_daemon: DockerContainer | None = None self._containers_started = False logger.info("DifyTestContainers initialized - ready to manage test containers") diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index b6fe8b73a2..21a792de06 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -1,6 +1,5 @@ import unittest from datetime import UTC, datetime 
-from typing import Optional from unittest.mock import patch from uuid import uuid4 @@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.rollback() def _create_upload_file( - self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None ) -> UploadFile: """Helper method to create an UploadFile record for testing.""" if file_id is None: @@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase): return upload_file def _create_tool_file( - self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None ) -> ToolFile: """Helper method to create a ToolFile record for testing.""" if file_id is None: @@ -102,9 +101,7 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file( - self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None - ) -> File: + def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id
diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 415e65ce51..c98406d845 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -13,7 +13,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, @@ -91,6 +90,28 @@ class TestAccountService: assert account.password is None assert account.password_salt is None + def test_create_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation with an invalid password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Test with a password that fails format validation (letters and underscores only, no digits) + with pytest.raises(ValueError): # Password validation error + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password="invalid_new_password", + ) + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): """ Test account creation when registration is disabled. @@ -139,7 +160,7 @@ class TestAccountService: fake = Faker() email = fake.email() password = fake.password(length=12) - with pytest.raises(AccountNotFoundError): + with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): @@ -940,7 +961,8 @@ class TestAccountService: Test getting user through non-existent email.
""" fake = Faker() - non_existent_email = fake.email() + domain = f"test-{fake.random_letters(10)}.com" + non_existent_email = fake.email(domain=domain) found_user = AccountService.get_user_through_email(non_existent_email) assert found_user is None @@ -3278,7 +3300,7 @@ class TestRegisterService: redis_client.setex(cache_key, 24 * 60 * 60, account_id) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token( + result = RegisterService.get_invitation_by_token( token=token, workspace_id=workspace_id, email=email, @@ -3316,7 +3338,7 @@ class TestRegisterService: redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token(token=token) + result = RegisterService.get_invitation_by_token(token=token) # Verify result contains expected data assert result is not None diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 9ed9008af9..3ec265d009 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService: # Test data for Baichuan model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService: # Test data for common model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService: for model_name in test_cases: args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": model_name, "has_context": "true", @@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None 
@@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for 
app_mode in app_modes: @@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService: # Test edge cases edge_cases = [ {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"}, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "", @@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 4646531a4e..d0f7e945f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -255,7 +255,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Try to create metadata with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name metadata_args = MetadataArgs(type="string", name=built_in_field_name) # Act & Assert: Verify proper error handling @@ -375,7 +375,7 @@ class TestMetadataService: metadata = MetadataService.create_metadata(dataset.id, 
metadata_args) # Try to update with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) @@ -540,11 +540,11 @@ class TestMetadataService: field_names = [field["name"] for field in result] field_types = [field["type"] for field in result] - assert BuiltInField.document_name.value in field_names - assert BuiltInField.uploader.value in field_names - assert BuiltInField.upload_date.value in field_names - assert BuiltInField.last_update_date.value in field_names - assert BuiltInField.source.value in field_names + assert BuiltInField.document_name in field_names + assert BuiltInField.uploader in field_names + assert BuiltInField.upload_date in field_names + assert BuiltInField.last_update_date in field_names + assert BuiltInField.source in field_names # Verify field types assert "string" in field_types @@ -682,11 +682,11 @@ class TestMetadataService: # Set document metadata with built-in fields document.doc_metadata = { - BuiltInField.document_name.value: document.name, - BuiltInField.uploader.value: "test_uploader", - BuiltInField.upload_date.value: 1234567890.0, - BuiltInField.last_update_date.value: 1234567890.0, - BuiltInField.source.value: "test_source", + BuiltInField.document_name: document.name, + BuiltInField.uploader: "test_uploader", + BuiltInField.upload_date: 1234567890.0, + BuiltInField.last_update_date: 1234567890.0, + BuiltInField.source: "test_source", } db.session.add(document) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index cb20238f0c..66527dd506 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -468,7 +469,7 @@ class TestModelLoadBalancingService: assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = ( - db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() - ) + inherit_configs = db.session.scalars( + select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") + ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index d09a4a17ab..04cff397b2 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy import select from werkzeug.exceptions import NotFound from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -954,7 +955,9 @@ class TestTagService: from extensions.ext_database import db # Verify only one binding exists - bindings = 
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + bindings = db.session.scalars( + select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) + ).all() assert len(bindings) == 1 def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): @@ -1064,7 +1067,9 @@ class TestTagService: # No error should be raised, and database state should remain unchanged from extensions.ext_database import db - bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + bindings = db.session.scalars( + select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) + ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6d6f1dab72..c9ace46c55 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account @@ -354,16 +355,14 @@ class TestWebConversationService: # Verify only one pinned conversation record exists from extensions.ext_database import db - pinned_conversations = ( - db.session.query(PinnedConversation) - .where( + pinned_conversations = db.session.scalars( + select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, PinnedConversation.created_by_role == "account", PinnedConversation.created_by == account.id, ) - .all() - ) + ).all() assert len(pinned_conversations) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 429056f5e2..316cfe1674 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -1,3 +1,5 @@ +import time +import uuid from unittest.mock import patch import pytest @@ -248,9 +250,15 @@ class TestWebAppAuthService: - Proper error handling for non-existent accounts - Correct exception type and message """ - # Arrange: Use non-existent email - fake = Faker() - non_existent_email = fake.email() + # Arrange: Generate a guaranteed non-existent email + # Use UUID and timestamp to ensure uniqueness + unique_id = str(uuid.uuid4()).replace("-", "") + timestamp = str(int(time.time() * 1000000)) # microseconds + non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid" + + # Double-check this email doesn't exist in the database + existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first() + assert existing_account is None, f"Test email {non_existent_email} already exists in database" # Act & Assert: Verify proper error handling with pytest.raises(AccountNotFoundError): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py 
b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index eb7a5a23f6..60150667ed 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -96,7 +96,7 @@ class TestWorkflowService: app.tenant_id = fake.uuid4() app.name = fake.company() app.description = fake.text() - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW app.icon_type = "emoji" app.icon = "🤖" app.icon_background = "#FFEAD5" @@ -883,7 +883,7 @@ class TestWorkflowService: # Create chat mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create app model config (required for conversion) from models.model import AppModelConfig @@ -926,7 +926,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW + assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -945,7 +945,7 @@ class TestWorkflowService: # Create completion mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION # Create app model config (required for conversion) from models.model import AppModelConfig @@ -988,7 +988,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.WORKFLOW.value + assert result.mode == AppMode.WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -1007,7 +1007,7 @@ class TestWorkflowService: # Create workflow mode app (already in workflow mode) app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db @@ -1030,7 +1030,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.ADVANCED_CHAT.value + app.mode = AppMode.ADVANCED_CHAT from extensions.ext_database import db @@ -1061,7 +1061,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8b3db27525..18ab4bb73c 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( VariableEntityType, ) from core.model_runtime.entities.llm_entities import LLMMode +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.account import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig @@ -37,7 +38,7 @@ class TestWorkflowConverter: # Setup default mock returns mock_encrypter.decrypt_token.return_value 
= "decrypted_api_key" mock_prompt_transform.return_value.get_prompt_template.return_value = { - "prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(), + "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"), "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"}, } mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config() diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 065bcc2cd7..fcae93c669 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances. import uuid from datetime import datetime +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(csv_content) + Path(file_path).write_text(csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download @@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask: db.session.commit() # Test each unavailable document - for i, document in enumerate(test_cases): + for document in test_cases: job_id = str(uuid.uuid4()) batch_create_segment_to_index_task( job_id=job_id, @@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(empty_csv_content) + Path(file_path).write_text(empty_csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download @@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(csv_content) + Path(file_path).write_text(csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 0083011070..e0c2da63b9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -362,7 +362,7 @@ class TestCleanDatasetTask: # Create segments for each document segments = [] - for i, document in enumerate(documents): + for document in documents: segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document) segments.append(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py new file mode 100644 index 0000000000..de81295100 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -0,0 +1,1099 @@ +""" +Integration tests for create_segment_to_index_task using TestContainers. 
+ +This module provides comprehensive testing for the create_segment_to_index_task +which handles asynchronous document segment indexing operations. +""" + +import time +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from faker import Faker + +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.create_segment_to_index_task import create_segment_to_index_task + + +class TestCreateSegmentToIndexTask: + """Integration tests for create_segment_to_index_task using testcontainers.""" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database and Redis before each test to ensure isolation.""" + from extensions.ext_database import db + + # Clear all test data + db.session.query(DocumentSegment).delete() + db.session.query(Document).delete() + db.session.query(Dataset).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + # Clear Redis cache + redis_client.flushdb() + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory, + ): + # Setup default mock returns + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "index_processor_factory": mock_factory, + "index_processor": mock_processor, + } + + def _create_test_account_and_tenant(self, db_session_with_containers): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + plan="basic", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join with owner role + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + """ + Helper method to create a test dataset and document for testing. 
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            tenant_id: Tenant ID for the dataset
+            account_id: Account ID for the document
+
+        Returns:
+            tuple: (dataset, document) - Created dataset and document instances
+        """
+        fake = Faker()
+
+        # Create dataset
+        dataset = Dataset(
+            name=fake.company(),
+            description=fake.text(max_nb_chars=100),
+            tenant_id=tenant_id,
+            data_source_type="upload_file",
+            indexing_technique="high_quality",
+            embedding_model_provider="openai",
+            embedding_model="text-embedding-ada-002",
+            created_by=account_id,
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.commit()
+
+        # Create document
+        document = Document(
+            name=fake.file_name(),
+            dataset_id=dataset.id,
+            tenant_id=tenant_id,
+            position=1,
+            data_source_type="upload_file",
+            batch="test_batch",
+            created_from="upload_file",
+            created_by=account_id,
+            enabled=True,
+            archived=False,
+            indexing_status="completed",
+            doc_form="qa_model",
+        )
+        db_session_with_containers.add(document)
+        db_session_with_containers.commit()
+
+        return dataset, document
+
+    def _create_test_segment(
+        self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting"
+    ):
+        """
+        Helper method to create a test document segment for testing.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            dataset_id: Dataset ID for the segment
+            document_id: Document ID for the segment
+            tenant_id: Tenant ID for the segment
+            account_id: Account ID for the segment
+            status: Initial status of the segment
+
+        Returns:
+            DocumentSegment: Created document segment instance
+        """
+        fake = Faker()
+
+        # Derive counts from the same generated content so the fixture is self-consistent
+        content = fake.text(max_nb_chars=500)
+        segment = DocumentSegment(
+            tenant_id=tenant_id,
+            dataset_id=dataset_id,
+            document_id=document_id,
+            position=1,
+            content=content,
+            answer=fake.text(max_nb_chars=200),
+            word_count=len(content.split()),
+            tokens=len(content.split()) * 2,
+            keywords=["test", "document", "segment"],
+            index_node_id=str(uuid4()),
+            index_node_hash=str(uuid4()),
+            status=status,
+            created_by=account_id,
+        )
+        db_session_with_containers.add(segment)
+        db_session_with_containers.commit()
+
+        return segment
+
+    def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test successful creation of segment to index.
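+
+        For reference, the Redis bookkeeping asserted here follows the cache-key
+        convention used throughout this module (taken from the assertions below):
+
+            cache_key = f"segment_{segment.id}_indexing"
+            assert redis_client.exists(cache_key) == 0  # removed once indexing finishes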
+ + This test verifies: + - Segment status transitions from waiting to indexing to completed + - Index processor is called with correct parameters + - Segment metadata is properly updated + - Redis cache key is cleaned up + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify segment status changes + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + # Verify Redis cache cleanup + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_segment_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent segment ID. + + This test verifies: + - Task gracefully handles missing segment + - No exceptions are raised + - Database session is properly closed + """ + # Arrange: Use non-existent segment ID + non_existent_segment_id = str(uuid4()) + + # Act & Assert: Task should complete without error + result = create_segment_to_index_task(non_existent_segment_id) + assert result is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_invalid_status( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with invalid status. + + This test verifies: + - Task skips segments not in 'waiting' status + - No processing occurs for invalid status + - Database session is properly closed + """ + # Arrange: Create segment with invalid status + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status unchanged + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated dataset. 
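+
+        Inferred control flow (a sketch based on this module's assertions, not
+        the task's actual source):
+
+            segment.status = "indexing"  # set optimistically first
+            if dataset is None:
+                return                   # early exit; the index processor is never built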
+
+        This test verifies:
+        - Task gracefully handles missing dataset
+        - Segment is moved to "indexing" and left there once the dataset lookup fails
+        - No processing occurs
+        """
+        # Arrange: Create segment with invalid dataset_id
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        invalid_dataset_id = str(uuid4())
+
+        # Create document with invalid dataset_id
+        document = Document(
+            name="test_doc",
+            dataset_id=invalid_dataset_id,
+            tenant_id=tenant.id,
+            position=1,
+            data_source_type="upload_file",
+            batch="test_batch",
+            created_from="upload_file",
+            created_by=account.id,
+            enabled=True,
+            archived=False,
+            indexing_status="completed",
+            doc_form="text_model",
+        )
+        db_session_with_containers.add(document)
+        db_session_with_containers.commit()
+
+        segment = self._create_test_segment(
+            db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting"
+        )
+
+        # Act: Execute the task
+        result = create_segment_to_index_task(segment.id)
+
+        # Assert: Task should complete without processing
+        assert result is None
+
+        # Verify segment status changed to indexing (the task updates status before validating the dataset)
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "indexing"
+
+        # Verify no index processor calls were made
+        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
+
+    def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies):
+        """
+        Test handling of segment without associated document.
+
+        This test verifies:
+        - Task gracefully handles missing document
+        - Segment is moved to "indexing" and left there once the document lookup fails
+        - No processing occurs
+        """
+        # Arrange: Create segment with invalid document_id
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, _ = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+        invalid_document_id = str(uuid4())
+
+        segment = self._create_test_segment(
+            db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting"
+        )
+
+        # Act: Execute the task
+        result = create_segment_to_index_task(segment.id)
+
+        # Assert: Task should complete without processing
+        assert result is None
+
+        # Verify segment status changed to indexing (the task updates status before checking the document)
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "indexing"
+
+        # Verify no index processor calls were made
+        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
+
+    def test_create_segment_to_index_document_disabled(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of segment with disabled document.
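+
+        A minimal sketch of the guard exercised here and in the archived-document
+        test below (hypothetical shape; the real task combines several checks):
+
+            if not document.enabled or document.archived:
+                return  # segment stays in "indexing"; the processor is never invoked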
+
+        This test verifies:
+        - Task skips segments whose document is disabled
+        - No processing occurs for disabled documents
+        - Segment is left in the intermediate "indexing" status
+        """
+        # Arrange: Create disabled document
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+
+        # Disable the document
+        document.enabled = False
+        db_session_with_containers.commit()
+
+        segment = self._create_test_segment(
+            db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
+        )
+
+        # Act: Execute the task
+        result = create_segment_to_index_task(segment.id)
+
+        # Assert: Task should complete without processing
+        assert result is None
+
+        # Verify segment status changed to indexing (the task updates status before checking the document)
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "indexing"
+
+        # Verify no index processor calls were made
+        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
+
+    def test_create_segment_to_index_document_archived(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of segment with archived document.
+
+        This test verifies:
+        - Task skips segments whose document is archived
+        - No processing occurs for archived documents
+        - Segment is left in the intermediate "indexing" status
+        """
+        # Arrange: Create archived document
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+
+        # Archive the document
+        document.archived = True
+        db_session_with_containers.commit()
+
+        segment = self._create_test_segment(
+            db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
+        )
+
+        # Act: Execute the task
+        result = create_segment_to_index_task(segment.id)
+
+        # Assert: Task should complete without processing
+        assert result is None
+
+        # Verify segment status changed to indexing (the task updates status before checking the document)
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "indexing"
+
+        # Verify no index processor calls were made
+        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
+
+    def test_create_segment_to_index_document_indexing_incomplete(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of segment with document that has incomplete indexing.
+
+        This test verifies:
+        - Task skips segments whose parent document has not finished indexing
+        - No processing occurs for incomplete indexing
+        - Segment is left in the intermediate "indexing" status
+        """
+        # Arrange: Create document with incomplete indexing
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+
+        # Set incomplete indexing status
+        document.indexing_status = "indexing"
+        db_session_with_containers.commit()
+
+        segment = self._create_test_segment(
+            db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
+        )
+
+        # Act: Execute the task
+        result = create_segment_to_index_task(segment.id)
+
+        # Assert: Task should complete without processing
+        assert result is None
+
+        # Verify segment status changed to indexing (the task updates status before checking the document)
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "indexing"
+
+        # Verify no index processor calls were made
+        mock_external_service_dependencies["index_processor_factory"].assert_not_called()
+
+    def test_create_segment_to_index_processor_exception(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test handling of index processor exceptions.
+
+        This test verifies:
+        - Task properly handles index processor failures
+        - Segment status is updated to error
+        - Segment is disabled with error information
+        - Redis cache is cleaned up despite errors
+        """
+        # Arrange: Create test data and mock processor exception
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+        segment = self._create_test_segment(
+            db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
+        )
+
+        # Mock processor to raise exception
+        mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Processor failed")
+
+        # Act: Execute the task
+        create_segment_to_index_task(segment.id)
+
+        # Assert: Verify error handling
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "error"
+        assert segment.enabled is False
+        assert segment.disabled_at is not None
+        assert segment.error == "Processor failed"
+
+        # Verify Redis cache cleanup still occurs
+        cache_key = f"segment_{segment.id}_indexing"
+        assert redis_client.exists(cache_key) == 0
+
+    def test_create_segment_to_index_with_keywords(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test segment indexing with custom keywords.
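+
+        The keywords are passed straight through the task's signature, exactly as
+        this test's Act step does:
+
+            create_segment_to_index_task(segment.id, keywords=["custom", "keywords", "test"])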
+ + This test verifies: + - Task accepts and processes keywords parameter + - Keywords are properly passed through the task + - Indexing completes successfully with keywords + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + custom_keywords = ["custom", "keywords", "test"] + + # Act: Execute the task with keywords + create_segment_to_index_task(segment.id, keywords=custom_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_different_doc_forms( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with different document forms. + + This test verifies: + - Task works with various document forms + - Index processor factory receives correct doc_form + - Processing completes successfully for different forms + """ + # Arrange: Test different doc_forms + doc_forms = ["qa_model", "text_model", "web_model"] + + for doc_form in doc_forms: + # Create fresh test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, tenant.id, account.id + ) + + # Update document's doc_form for testing + document.doc_form = doc_form + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify correct doc_form was passed to factory + mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) + + def test_create_segment_to_index_performance_timing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing performance and timing. 
+ + This test verifies: + - Task execution time is reasonable + - Performance metrics are properly recorded + - No significant performance degradation + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task and measure time + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify performance + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + # Verify successful completion + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + def test_create_segment_to_index_concurrent_execution( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test concurrent execution of segment indexing tasks. + + This test verifies: + - Multiple tasks can run concurrently + - No race conditions occur + - All segments are processed correctly + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segments = [] + for i in range(3): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + segment_ids = [segment.id for segment in segments] + for segment_id in segment_ids: + create_segment_to_index_task(segment_id) + + # Assert: Verify all segments processed + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called for each segment + assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 + + def test_create_segment_to_index_large_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with large content. 
+
+        This test verifies:
+        - Task handles large content segments
+        - Performance remains acceptable with large content
+        - No memory or processing issues occur
+        """
+        # Arrange: Create segment with large content
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+
+        # Generate large content (simulate large document)
+        large_content = "Large content " * 1000  # ~15KB content
+        segment = DocumentSegment(
+            tenant_id=tenant.id,
+            dataset_id=dataset.id,
+            document_id=document.id,
+            position=1,
+            content=large_content,
+            answer="Large answer " * 100,
+            word_count=len(large_content.split()),
+            tokens=len(large_content.split()) * 2,
+            keywords=["large", "content", "test"],
+            index_node_id=str(uuid4()),
+            index_node_hash=str(uuid4()),
+            status="waiting",
+            created_by=account.id,
+        )
+        db_session_with_containers.add(segment)
+        db_session_with_containers.commit()
+
+        # Act: Execute the task
+        start_time = time.time()
+        create_segment_to_index_task(segment.id)
+        end_time = time.time()
+
+        # Assert: Verify successful processing
+        execution_time = end_time - start_time
+        assert execution_time < 10.0  # Should complete within 10 seconds
+
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "completed"
+        assert segment.indexing_at is not None
+        assert segment.completed_at is not None
+
+    def test_create_segment_to_index_redis_failure(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test segment indexing when Redis operations fail.
+
+        This test verifies:
+        - Indexing completes successfully even when the Redis cleanup fails
+        - The Redis exception raised during cleanup propagates to the caller
+        - Redis errors don't corrupt the segment's final state
+        """
+        # Arrange: Create test data and mock Redis failure
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
+        segment = self._create_test_segment(
+            db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
+        )
+
+        # Set up Redis cache key to simulate indexing in progress
+        cache_key = f"segment_{segment.id}_indexing"
+        redis_client.set(cache_key, "processing", ex=300)
+
+        # Mock Redis to raise exception in finally block
+        with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")):
+            # Act: Execute the task; indexing finishes first, then the cleanup error surfaces
+            with pytest.raises(Exception) as exc_info:
+                create_segment_to_index_task(segment.id)
+
+            # Verify the exception contains the expected Redis error message
+            assert "Redis connection failed" in str(exc_info.value)
+
+        # Assert: Verify indexing still completed successfully despite Redis failure
+        db_session_with_containers.refresh(segment)
+        assert segment.status == "completed"
+        assert segment.indexing_at is not None
+        assert segment.completed_at is not None
+
+        # Verify Redis cache key still exists (since delete failed)
+        assert redis_client.exists(cache_key) == 1
+
+    def test_create_segment_to_index_database_transaction_rollback(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test segment indexing with database transaction handling.
+ + This test verifies: + - Database transactions are properly managed + - Rollback occurs on errors + - Data consistency is maintained + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock global database session to simulate transaction issues + from extensions.ext_database import db + + original_commit = db.session.commit + commit_called = False + + def mock_commit(): + nonlocal commit_called + if not commit_called: + commit_called = True + raise Exception("Database commit failed") + return original_commit() + + db.session.commit = mock_commit + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling and rollback + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error is not None + + # Restore original commit method + db.session.commit = original_commit + + def test_create_segment_to_index_metadata_validation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with metadata validation. + + This test verifies: + - Document metadata is properly constructed + - All required metadata fields are present + - Metadata is correctly passed to index processor + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify index processor was called with correct metadata + mock_processor = mock_external_service_dependencies["index_processor"] + mock_processor.load.assert_called_once() + + # Get the call arguments to verify metadata structure + call_args = mock_processor.load.call_args + assert len(call_args[0]) == 2 # dataset and documents + + # Verify basic structure without deep object inspection + called_dataset = call_args[0][0] # first arg should be dataset + assert called_dataset is not None + + documents = call_args[0][1] # second arg should be list of documents + assert len(documents) == 1 + doc = documents[0] + assert doc is not None + + def test_create_segment_to_index_status_transition_flow( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test complete status transition flow during indexing. 
+ + This test verifies: + - Status transitions: waiting -> indexing -> completed + - Timestamps are properly recorded at each stage + - No intermediate states are skipped + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Verify initial state + assert segment.status == "waiting" + assert segment.indexing_at is None + assert segment.completed_at is None + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify final state + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify timestamp ordering + assert segment.indexing_at <= segment.completed_at + + def test_create_segment_to_index_with_empty_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with empty or minimal content. + + This test verifies: + - Task handles empty content gracefully + - Indexing completes successfully with minimal content + - No errors occur with edge case content + """ + # Arrange: Create segment with minimal content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content="", # Empty content + answer="", + word_count=0, + tokens=0, + keywords=[], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_special_characters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with special characters and unicode content. 
+ + This test verifies: + - Task handles special characters correctly + - Unicode content is processed properly + - No encoding issues occur + """ + # Arrange: Create segment with special characters + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~" + unicode_content = "Unicode: 中文测试 🚀 🌟 💻" + mixed_content = special_content + "\n" + unicode_content + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=mixed_content, + answer="Special answer: 🎯", + word_count=len(mixed_content.split()), + tokens=len(mixed_content.split()) * 2, + keywords=["special", "unicode", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_long_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with long keyword lists. + + This test verifies: + - Task handles long keyword lists + - Keywords parameter is properly processed + - No performance issues with large keyword sets + """ + # Arrange: Create segment with long keywords + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Create long keyword list + long_keywords = [f"keyword_{i}" for i in range(100)] + + # Act: Execute the task with long keywords + create_segment_to_index_task(segment.id, keywords=long_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with proper tenant isolation. 
+ + This test verifies: + - Tasks are properly isolated by tenant + - No cross-tenant data access occurs + - Tenant boundaries are respected + """ + # Arrange: Create multiple tenants with segments + account1, tenant1 = self._create_test_account_and_tenant(db_session_with_containers) + account2, tenant2 = self._create_test_account_and_tenant(db_session_with_containers) + + dataset1, document1 = self._create_test_dataset_and_document( + db_session_with_containers, tenant1.id, account1.id + ) + dataset2, document2 = self._create_test_dataset_and_document( + db_session_with_containers, tenant2.id, account2.id + ) + + segment1 = self._create_test_segment( + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + ) + segment2 = self._create_test_segment( + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + ) + + # Act: Execute tasks for both tenants + create_segment_to_index_task(segment1.id) + create_segment_to_index_task(segment2.id) + + # Assert: Verify both segments processed independently + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + + assert segment1.status == "completed" + assert segment2.status == "completed" + assert segment1.tenant_id == tenant1.id + assert segment2.tenant_id == tenant2.id + assert segment1.tenant_id != segment2.tenant_id + + def test_create_segment_to_index_with_none_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with None keywords parameter. + + This test verifies: + - Task handles None keywords gracefully + - Default behavior works correctly + - No errors occur with None parameters + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task with None keywords + create_segment_to_index_task(segment.id, keywords=None) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_comprehensive_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Comprehensive integration test covering multiple scenarios. 
+ + This test verifies: + - Complete workflow from creation to completion + - All components work together correctly + - End-to-end functionality is maintained + - Performance and reliability under normal conditions + """ + # Arrange: Create comprehensive test scenario + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Create multiple segments with different characteristics + segments = [] + for i in range(5): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Process all segments + start_time = time.time() + for segment in segments: + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify comprehensive success + total_time = end_time - start_time + assert total_time < 25.0 # Should complete all within 25 seconds + + # Verify all segments processed successfully + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called for each segment + expected_calls = len(segments) + assert mock_external_service_dependencies["index_processor_factory"].call_count == expected_calls + + # Verify Redis cleanup for each segment + for segment in segments: + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py new file mode 100644 index 0000000000..cebad6de9e --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -0,0 +1,1391 @@ +""" +Integration tests for deal_dataset_vector_index_task using TestContainers. + +This module tests the deal_dataset_vector_index_task functionality with real database +containers to ensure proper handling of dataset vector index operations including +add, update, and remove actions. 
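+
+The task takes a dataset id plus an action string; the tests below drive all
+three actions (signature inferred from the calls in this module):
+
+    deal_dataset_vector_index_task(dataset.id, "add")
+    deal_dataset_vector_index_task(dataset.id, "update")
+    deal_dataset_vector_index_task(dataset.id, "remove")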
+""" + +import uuid +from unittest.mock import ANY, Mock, patch + +import pytest +from faker import Faker + +from models.dataset import Dataset, Document, DocumentSegment +from services.account_service import AccountService, TenantService +from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task + + +class TestDealDatasetVectorIndexTask: + """Integration tests for deal_dataset_vector_index_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + yield { + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_index_processor(self): + """Mock IndexProcessor for testing.""" + mock_processor = Mock() + mock_processor.clean = Mock() + mock_processor.load = Mock() + return mock_processor + + @pytest.fixture + def mock_index_processor_factory(self, mock_index_processor): + """Mock IndexProcessorFactory for testing.""" + with patch("tasks.deal_dataset_vector_index_task.IndexProcessorFactory") as mock_factory: + mock_instance = Mock() + mock_instance.init_index_processor.return_value = mock_index_processor + mock_factory.return_value = mock_instance + yield mock_factory + + def test_deal_dataset_vector_index_task_remove_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful removal of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Calls index processor to clean vector indices + 3. Handles the remove action properly + 4. 
Completes without errors
+        """
+        fake = Faker()
+
+        # Create test data
+        account = AccountService.create_account(
+            email=fake.email(),
+            name=fake.name(),
+            interface_language="en-US",
+            password=fake.password(length=12),
+        )
+        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
+        tenant = account.current_tenant
+
+        # Create dataset
+        dataset = Dataset(
+            id=str(uuid.uuid4()),
+            tenant_id=tenant.id,
+            name=fake.company(),
+            description=fake.text(max_nb_chars=100),
+            data_source_type="file_import",
+            created_by=account.id,
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        # Create a document to set the doc_form property
+        document_for_doc_form = Document(
+            id=str(uuid.uuid4()),
+            tenant_id=tenant.id,
+            dataset_id=dataset.id,
+            position=0,
+            data_source_type="file_import",
+            name="Document for doc_form",
+            created_from="file_import",
+            created_by=account.id,
+            doc_form="text_model",
+            doc_language="en",
+            indexing_status="completed",
+            enabled=True,
+            archived=False,
+            batch="test_batch",
+        )
+        db_session_with_containers.add(document_for_doc_form)
+        db_session_with_containers.commit()
+
+        # Execute remove action
+        deal_dataset_vector_index_task(dataset.id, "remove")
+
+        # Resolve the mocked index processor that the task would have used
+        mock_factory = mock_index_processor_factory.return_value
+        mock_processor = mock_factory.init_index_processor.return_value
+
+        # NOTE: `call_count >= 0` is always true; this is only a smoke check that
+        # the remove action ran without raising. The clean call itself is not
+        # strictly asserted yet.
+        assert mock_processor.clean.call_count >= 0
+
+    def test_deal_dataset_vector_index_task_add_action_success(
+        self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
+    ):
+        """
+        Test successful addition of dataset vector index.
+
+        This test verifies that the task correctly:
+        1. Finds the dataset in the database
+        2. Queries for completed documents
+        3. Updates document indexing status
+        4. Processes document segments
+        5. Calls index processor to load documents
+        6.
Updates document status to completed + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create documents + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor load method was called + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_update_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful update of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Queries for completed documents + 3. Updates document indexing status + 4. Cleans existing index + 5. Processes document segments with parent-child structure + 6. Calls index processor to load documents + 7. 
Updates document status to completed + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset with parent-child index + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="parent_child_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="parent_child_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute update action + deal_dataset_vector_index_task(dataset.id, "update") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor clean and load methods were called + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_called_once_with(ANY, None, with_keywords=False, delete_child_chunks=False) + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_dataset_not_found_error( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior when dataset is not found. + + This test verifies that the task properly handles the case where + the specified dataset does not exist in the database. 
+ """ + non_existent_dataset_id = str(uuid.uuid4()) + + # Execute task with non-existent dataset + deal_dataset_vector_index_task(non_existent_dataset_id, "add") + + # Verify that no index processor operations were performed + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_not_called() + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_no_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action when no documents exist for the dataset. + + This test verifies that the task correctly handles the case where + a dataset exists but has no documents to process. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify that no index processor operations were performed + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_no_segments( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action when documents exist but have no segments. + + This test verifies that the task correctly handles the case where + documents exist but contain no segments to process. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document without segments + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify that no index processor load was called since no segments exist + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_update_action_no_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test update action when no documents exist for the dataset. + + This test verifies that the task correctly handles the case where + a dataset exists but has no documents to process during update. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Execute update action + deal_dataset_vector_index_task(dataset.id, "update") + + # Verify that index processor clean was called but no load + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_called_once_with(ANY, None, with_keywords=False, delete_child_chunks=False) + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_with_exception_handling( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action with exception handling during processing. + + This test verifies that the task correctly handles exceptions + during document processing and updates document status to error. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Mock index processor to raise exception during load + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.side_effect = Exception("Test exception during indexing") + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to error + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "error" + assert "Test exception during indexing" in updated_document.error + + def test_deal_dataset_vector_index_task_with_custom_index_type( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with custom index type (QA_INDEX). + + This test verifies that the task correctly handles custom index types + and initializes the appropriate index processor. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset with custom index type + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="qa_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor was initialized with custom index type + mock_index_processor_factory.assert_called_once_with("qa_index") + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_default_index_type( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with default index type (PARAGRAPH_INDEX). + + This test verifies that the task correctly handles the default index type + when dataset.doc_form is None. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without doc_form (should use default) + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor was initialized with the document's index type + mock_index_processor_factory.assert_called_once_with("text_model") + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_multiple_documents_processing( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task processing with multiple documents and segments. + + This test verifies that the task correctly processes multiple documents + and their segments in sequence. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create multiple documents + documents = [] + for i in range(3): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="file_import", + name=f"Test Document {i}", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + documents.append(document) + + db_session_with_containers.flush() + + # Create segments for each document + for i, document in enumerate(documents): + for j in range(2): + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=j, + content=f"Content {i}-{j} for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{i}_{j}", + index_node_hash=f"hash_{i}_{j}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify all documents were processed + for document in documents: + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor load was called multiple times + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + assert mock_processor.load.call_count == 3 + + def test_deal_dataset_vector_index_task_document_status_transitions( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test document status transitions during task execution. + + This test verifies that document status correctly transitions from + 'completed' to 'indexing' and back to 'completed' during processing. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Mock index processor to capture intermediate state + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + + # Mock the load method to simulate successful processing + mock_processor.load.return_value = None + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify final document status + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + def test_deal_dataset_vector_index_task_with_disabled_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with disabled documents. + + This test verifies that the task correctly skips disabled documents + during processing. 
+ """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create enabled document + enabled_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Enabled Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(enabled_document) + + # Create disabled document + disabled_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Disabled Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=False, # This document should be skipped + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(disabled_document) + + db_session_with_containers.flush() + + # Create segments for enabled document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=enabled_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only enabled document was processed + updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() + assert updated_enabled_document.indexing_status == "completed" + + # Verify disabled document status remains unchanged + updated_disabled_document = ( + db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() + ) + assert updated_disabled_document.indexing_status == "completed" # Should not change + + # Verify index processor load was called only once (for enabled document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_archived_documents( + self, db_session_with_containers, 
mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with archived documents. + + This test verifies that the task correctly skips archived documents + during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create active document + active_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Active Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(active_document) + + # Create archived document + archived_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Archived Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=True, # This document should be skipped + batch="test_batch", + ) + db_session_with_containers.add(archived_document) + + db_session_with_containers.flush() + + # Create segments for active document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=active_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only active document was processed + updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() + assert updated_active_document.indexing_status == "completed" + + # Verify archived document status remains unchanged + updated_archived_document = ( + db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() + ) + assert updated_archived_document.indexing_status == "completed" # Should not change + + # Verify index processor load was called only once (for active document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = 
mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_incomplete_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with documents that have incomplete indexing status. + + This test verifies that the task correctly skips documents with + incomplete indexing status during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create completed document + completed_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Completed Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(completed_document) + + # Create incomplete document + incomplete_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Incomplete Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="indexing", # This document should be skipped + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(incomplete_document) + + db_session_with_containers.flush() + + # Create segments for completed document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=completed_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only completed document was processed + updated_completed_document = ( + db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() + ) + assert updated_completed_document.indexing_status == "completed" + + # Verify incomplete document status remains unchanged + updated_incomplete_document = ( + 
db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() + ) + assert updated_incomplete_document.indexing_status == "indexing" # Should not change + + # Verify index processor load was called only once (for completed document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py new file mode 100644 index 0000000000..7af4f238be --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -0,0 +1,583 @@ +""" +TestContainers-based integration tests for delete_segment_from_index_task. + +This module provides comprehensive integration testing for the delete_segment_from_index_task +using TestContainers to ensure realistic database interactions and proper isolation. +The task is responsible for removing document segments from the vector index when segments +are deleted from the dataset. +""" + +import logging +from unittest.mock import MagicMock, patch + +from faker import Faker + +from core.rag.index_processor.constant.index_type import IndexType +from models import Account, Dataset, Document, DocumentSegment, Tenant +from tasks.delete_segment_from_index_task import delete_segment_from_index_task + +logger = logging.getLogger(__name__) + + +class TestDeleteSegmentFromIndexTask: + """ + Comprehensive integration tests for delete_segment_from_index_task using testcontainers. + + This test class covers all major functionality of the delete_segment_from_index_task: + - Successful segment deletion from index + - Dataset not found scenarios + - Document not found scenarios + - Document status validation (disabled, archived, not completed) + - Index processor integration and cleanup + - Exception handling and error scenarios + - Performance and timing verification + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + def _create_test_tenant(self, db_session_with_containers, fake=None): + """ + Helper method to create a test tenant with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Tenant: Created test tenant instance + """ + fake = fake or Faker() + tenant = Tenant() + tenant.id = fake.uuid4() + tenant.name = f"Test Tenant {fake.company()}" + tenant.plan = "basic" + tenant.status = "active" + tenant.created_at = fake.date_time_this_year() + tenant.updated_at = tenant.created_at + + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + return tenant + + def _create_test_account(self, db_session_with_containers, tenant, fake=None): + """ + Helper method to create a test account with realistic data. 
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            tenant: Tenant instance for the account
+            fake: Faker instance for generating test data
+
+        Returns:
+            Account: Created test account instance
+        """
+        fake = fake or Faker()
+        account = Account()
+        account.id = fake.uuid4()
+        account.email = fake.email()
+        account.name = fake.name()
+        account.avatar_url = fake.url()
+        account.tenant_id = tenant.id
+        account.status = "active"
+        account.type = "normal"
+        account.role = "owner"
+        account.interface_language = "en-US"
+        account.created_at = fake.date_time_this_year()
+        account.updated_at = account.created_at
+
+        db_session_with_containers.add(account)
+        db_session_with_containers.commit()
+        return account
+
+    def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None):
+        """
+        Helper method to create a test dataset with realistic data.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            tenant: Tenant instance for the dataset
+            account: Account instance for the dataset
+            fake: Faker instance for generating test data
+
+        Returns:
+            Dataset: Created test dataset instance
+        """
+        fake = fake or Faker()
+        dataset = Dataset()
+        dataset.id = fake.uuid4()
+        dataset.tenant_id = tenant.id
+        dataset.name = f"Test Dataset {fake.word()}"
+        dataset.description = fake.text(max_nb_chars=200)
+        dataset.provider = "vendor"
+        dataset.permission = "only_me"
+        dataset.data_source_type = "upload_file"
+        dataset.indexing_technique = "high_quality"
+        dataset.index_struct = '{"type": "paragraph"}'
+        dataset.created_by = account.id
+        dataset.created_at = fake.date_time_this_year()
+        dataset.updated_by = account.id
+        dataset.updated_at = dataset.created_at
+        dataset.embedding_model = "text-embedding-ada-002"
+        dataset.embedding_model_provider = "openai"
+        dataset.built_in_field_enabled = False
+
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.commit()
+        return dataset
+
+    def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs):
+        """
+        Helper method to create a test document with realistic data.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            dataset: Dataset instance for the document
+            account: Account instance for the document
+            fake: Faker instance for generating test data
+            **kwargs: Additional document attributes to override defaults
+
+        Returns:
+            Document: Created test document instance
+        """
+        fake = fake or Faker()
+        document = Document()
+        document.id = fake.uuid4()
+        document.tenant_id = dataset.tenant_id
+        document.dataset_id = dataset.id
+        document.position = kwargs.get("position", 1)
+        document.data_source_type = kwargs.get("data_source_type", "upload_file")
+        document.data_source_info = kwargs.get("data_source_info", "{}")
+        document.batch = kwargs.get("batch", fake.uuid4())
+        document.name = kwargs.get("name", f"Test Document {fake.word()}")
+        document.created_from = kwargs.get("created_from", "api")
+        document.created_by = account.id
+        document.created_at = fake.date_time_this_year()
+        document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year())
+        document.file_id = kwargs.get("file_id", fake.uuid4())
+        document.word_count = kwargs.get("word_count", fake.random_int(min=100, max=1000))
+        document.parsing_completed_at = kwargs.get("parsing_completed_at", fake.date_time_this_year())
+        document.cleaning_completed_at = kwargs.get("cleaning_completed_at", fake.date_time_this_year())
+        document.splitting_completed_at = kwargs.get("splitting_completed_at", fake.date_time_this_year())
+        document.tokens = kwargs.get("tokens", fake.random_int(min=50, max=500))
+        document.indexing_latency = kwargs.get("indexing_latency", fake.random_number(digits=3))
+        document.completed_at = kwargs.get("completed_at", fake.date_time_this_year())
+        document.is_paused = kwargs.get("is_paused", False)
+        document.indexing_status = kwargs.get("indexing_status", "completed")
+        document.enabled = kwargs.get("enabled", True)
+        document.archived = kwargs.get("archived", False)
+        document.updated_at = fake.date_time_this_year()
+        document.doc_type = kwargs.get("doc_type", "text")
+        document.doc_metadata = kwargs.get("doc_metadata", {})
+        document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX)
+        document.doc_language = kwargs.get("doc_language", "en")
+
+        db_session_with_containers.add(document)
+        db_session_with_containers.commit()
+        return document
+
+    def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None):
+        """
+        Helper method to create test document segments with realistic data.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+            document: Document instance for the segments
+            account: Account instance for the segments
+            count: Number of segments to create
+            fake: Faker instance for generating test data
+
+        Returns:
+            list[DocumentSegment]: List of created test document segment instances
+        """
+        fake = fake or Faker()
+        segments = []
+
+        for i in range(count):
+            segment = DocumentSegment()
+            segment.id = fake.uuid4()
+            segment.tenant_id = document.tenant_id
+            segment.dataset_id = document.dataset_id
+            segment.document_id = document.id
+            segment.position = i + 1
+            segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}"
+            segment.answer = f"Test segment answer {i + 1}: {fake.text(max_nb_chars=100)}"
+            segment.word_count = fake.random_int(min=10, max=100)
+            segment.tokens = fake.random_int(min=5, max=50)
+            segment.keywords = [fake.word() for _ in range(3)]
+            segment.index_node_id = f"node_{fake.uuid4()}"
+            segment.index_node_hash = fake.sha256()
+            segment.hit_count = 0
+            segment.enabled = True
+            segment.status = "completed"
+            segment.created_by = account.id
+            segment.created_at = fake.date_time_this_year()
+            segment.updated_by = account.id
+            segment.updated_at = segment.created_at
+
+            db_session_with_containers.add(segment)
+            segments.append(segment)
+
+        db_session_with_containers.commit()
+        return segments
+
+    @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
+    def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
+        """
+        Test successful segment deletion from index with comprehensive verification.
+
+        This test verifies:
+        - Proper task execution with valid dataset and document
+        - Index processor factory initialization with correct document form
+        - Index processor clean method called with correct parameters
+        - Database session properly closed after execution
+        - Task completes without exceptions
+        """
+        fake = Faker()
+
+        # Create test data
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(db_session_with_containers, dataset, account, fake)
+        segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
+
+        # Extract index node IDs for the task
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # Mock the index processor
+        mock_processor = MagicMock()
+        mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
+
+        # Execute the task
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed successfully
+        assert result is None  # Task should return None on success
+
+        # Verify index processor factory was called with correct document form
+        mock_index_processor_factory.assert_called_once_with(document.doc_form)
+
+        # Verify index processor clean method was called with correct parameters
+        # Note: We can't directly compare Dataset objects as they are different instances
+        # from database queries, so we verify the call was made and check the parameters
+        assert mock_processor.clean.call_count == 1
+        call_args = mock_processor.clean.call_args
+        assert call_args[0][0].id == dataset.id  # Verify dataset ID matches
+        assert call_args[0][1] == index_node_ids  # Verify index node IDs match
+        assert call_args[1]["with_keywords"] is True
+        assert call_args[1]["delete_child_chunks"] is True
+
+    def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers):
+        """
+        Test task behavior when dataset is not found.
+
+        This test verifies:
+        - Task handles missing dataset gracefully
+        - No index processor operations are attempted
+        - Task returns early without exceptions
+        - Database session is properly closed
+        """
+        fake = Faker()
+        non_existent_dataset_id = fake.uuid4()
+        non_existent_document_id = fake.uuid4()
+        index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
+
+        # Execute the task with non-existent dataset
+        result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id)
+
+        # Verify the task completed without exceptions
+        assert result is None  # Task should return None when dataset not found
+
+    def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers):
+        """
+        Test task behavior when document is not found.
+
+        This test verifies:
+        - Task handles missing document gracefully
+        - No index processor operations are attempted
+        - Task returns early without exceptions
+        - Database session is properly closed
+        """
+        fake = Faker()
+
+        # Create test data
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+
+        non_existent_document_id = fake.uuid4()
+        index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
+
+        # Execute the task with non-existent document
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id)
+
+        # Verify the task completed without exceptions
+        assert result is None  # Task should return None when document not found
+
+    def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers):
+        """
+        Test task behavior when document is disabled.
+
+        This test verifies:
+        - Task handles disabled document gracefully
+        - No index processor operations are attempted
+        - Task returns early without exceptions
+        - Database session is properly closed
+        """
+        fake = Faker()
+
+        # Create test data with disabled document
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(db_session_with_containers, dataset, account, fake, enabled=False)
+        segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
+
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # Execute the task with disabled document
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed without exceptions
+        assert result is None  # Task should return None when document is disabled
+
+    def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers):
+        """
+        Test task behavior when document is archived.
+
+        This test verifies:
+        - Task handles archived document gracefully
+        - No index processor operations are attempted
+        - Task returns early without exceptions
+        - Database session is properly closed
+        """
+        fake = Faker()
+
+        # Create test data with archived document
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(db_session_with_containers, dataset, account, fake, archived=True)
+        segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
+
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # Execute the task with archived document
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed without exceptions
+        assert result is None  # Task should return None when document is archived
+
+    def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers):
+        """
+        Test task behavior when document indexing is not completed.
+
+        This test verifies:
+        - Task handles incomplete indexing status gracefully
+        - No index processor operations are attempted
+        - Task returns early without exceptions
+        - Database session is properly closed
+        """
+        fake = Faker()
+
+        # Create test data with incomplete indexing
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(
+            db_session_with_containers, dataset, account, fake, indexing_status="indexing"
+        )
+        segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
+
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # Execute the task with incomplete indexing
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed without exceptions
+        assert result is None  # Task should return None when indexing is not completed
+
+    @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
+    def test_delete_segment_from_index_task_index_processor_clean(
+        self, mock_index_processor_factory, db_session_with_containers
+    ):
+        """
+        Test index processor clean method integration with different document forms.
+
+        This test verifies:
+        - Index processor factory creates correct processor for different document forms
+        - Clean method is called with proper parameters for each document form
+        - Task handles different index types correctly
+        - Database session is properly managed
+        """
+        fake = Faker()
+
+        # Test different document forms
+        document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
+
+        for doc_form in document_forms:
+            # Create test data for each document form
+            tenant = self._create_test_tenant(db_session_with_containers, fake)
+            account = self._create_test_account(db_session_with_containers, tenant, fake)
+            dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+            document = self._create_test_document(db_session_with_containers, dataset, account, fake, doc_form=doc_form)
+            segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
+
+            index_node_ids = [segment.index_node_id for segment in segments]
+
+            # Mock the index processor
+            mock_processor = MagicMock()
+            mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
+
+            # Execute the task
+            result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+            # Verify the task completed successfully
+            assert result is None
+
+            # Verify index processor factory was called with correct document form
+            mock_index_processor_factory.assert_called_with(doc_form)
+
+            # Verify index processor clean method was called with correct parameters
+            assert mock_processor.clean.call_count == 1
+            call_args = mock_processor.clean.call_args
+            assert call_args[0][0].id == dataset.id  # Verify dataset ID matches
+            assert call_args[0][1] == index_node_ids  # Verify index node IDs match
+            assert call_args[1]["with_keywords"] is True
+            assert call_args[1]["delete_child_chunks"] is True
+
+            # Reset mocks for next iteration
+            mock_index_processor_factory.reset_mock()
+            mock_processor.reset_mock()
+
+    @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
+    def test_delete_segment_from_index_task_exception_handling(
+        self, mock_index_processor_factory, db_session_with_containers
+    ):
+        """
+        Test exception handling in the task.
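+
+        The cleanup path is assumed to look roughly like this (a sketch
+        under that assumption, not the task's literal code):
+
+            try:
+                index_processor.clean(dataset, index_node_ids, ...)
+            except Exception:
+                logging.exception("Failed to delete segments from index")
+            finally:
+                db.session.close()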
+
+        This test verifies:
+        - Task handles index processor exceptions gracefully
+        - Database session is properly closed even when exceptions occur
+        - Task logs exceptions appropriately
+        - No unhandled exceptions are raised
+        """
+        fake = Faker()
+
+        # Create test data
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(db_session_with_containers, dataset, account, fake)
+        segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
+
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # Mock the index processor to raise an exception
+        mock_processor = MagicMock()
+        mock_processor.clean.side_effect = Exception("Index processor error")
+        mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
+
+        # Execute the task - should not raise exception
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed without raising exceptions
+        assert result is None  # Task should return None even when exceptions occur
+
+        # Verify index processor clean method was called
+        assert mock_processor.clean.call_count == 1
+        call_args = mock_processor.clean.call_args
+        assert call_args[0][0].id == dataset.id  # Verify dataset ID matches
+        assert call_args[0][1] == index_node_ids  # Verify index node IDs match
+        assert call_args[1]["with_keywords"] is True
+        assert call_args[1]["delete_child_chunks"] is True
+
+    @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
+    def test_delete_segment_from_index_task_empty_index_node_ids(
+        self, mock_index_processor_factory, db_session_with_containers
+    ):
+        """
+        Test task behavior with empty index node IDs list.
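+
+        The list is expected to be passed through to clean() unchanged, so
+        an empty list still results in exactly one clean call.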
+
+        This test verifies:
+        - Task handles empty index node IDs gracefully
+        - Index processor clean method is called with empty list
+        - Task completes successfully
+        - Database session is properly managed
+        """
+        fake = Faker()
+
+        # Create test data
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(db_session_with_containers, dataset, account, fake)
+
+        # Use empty index node IDs
+        index_node_ids = []
+
+        # Mock the index processor
+        mock_processor = MagicMock()
+        mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
+
+        # Execute the task
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed successfully
+        assert result is None
+
+        # Verify index processor clean method was called with empty list
+        assert mock_processor.clean.call_count == 1
+        call_args = mock_processor.clean.call_args
+        assert call_args[0][0].id == dataset.id  # Verify dataset ID matches
+        assert call_args[0][1] == index_node_ids  # Verify index node IDs match (empty list)
+        assert call_args[1]["with_keywords"] is True
+        assert call_args[1]["delete_child_chunks"] is True
+
+    @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
+    def test_delete_segment_from_index_task_large_index_node_ids(
+        self, mock_index_processor_factory, db_session_with_containers
+    ):
+        """
+        Test task behavior with large number of index node IDs.
+
+        This test verifies:
+        - Task handles large lists of index node IDs efficiently
+        - Index processor clean method is called with all node IDs
+        - Task completes successfully with large datasets
+        - Database session is properly managed
+        """
+        fake = Faker()
+
+        # Create test data
+        tenant = self._create_test_tenant(db_session_with_containers, fake)
+        account = self._create_test_account(db_session_with_containers, tenant, fake)
+        dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
+        document = self._create_test_document(db_session_with_containers, dataset, account, fake)
+
+        # Create large number of segments
+        segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # Mock the index processor
+        mock_processor = MagicMock()
+        mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
+
+        # Execute the task
+        result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
+
+        # Verify the task completed successfully
+        assert result is None
+
+        # Verify index processor clean method was called with all node IDs
+        assert mock_processor.clean.call_count == 1
+        call_args = mock_processor.clean.call_args
+        assert call_args[0][0].id == dataset.id  # Verify dataset ID matches
+        assert call_args[0][1] == index_node_ids  # Verify index node IDs match
+        assert call_args[1]["with_keywords"] is True
+        assert call_args[1]["delete_child_chunks"] is True
+
+        # Verify all node IDs were passed
+        assert len(call_args[0][1]) == 50
diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py
new file mode 100644
index 0000000000..e1d63e993b
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py
@@ -0,0 +1,615 @@
+"""
+Integration tests for disable_segment_from_index_task using TestContainers.
+
+This module provides comprehensive integration tests for the disable_segment_from_index_task
+using real database and Redis containers to ensure the task works correctly with actual
+data and external dependencies.
+"""
+
+import logging
+import time
+from datetime import UTC, datetime
+from unittest.mock import patch
+
+import pytest
+from faker import Faker
+
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from models.dataset import Dataset, Document, DocumentSegment
+from tasks.disable_segment_from_index_task import disable_segment_from_index_task
+
+logger = logging.getLogger(__name__)
+
+
+class TestDisableSegmentFromIndexTask:
+    """Integration tests for disable_segment_from_index_task using testcontainers."""
+
+    @pytest.fixture
+    def mock_index_processor(self):
+        """Mock IndexProcessorFactory and its clean method."""
+        with patch("tasks.disable_segment_from_index_task.IndexProcessorFactory") as mock_factory:
+            mock_processor = mock_factory.return_value.init_index_processor.return_value
+            mock_processor.clean.return_value = None
+            yield mock_processor
+
+    def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]:
+        """
+        Helper method to create a test account and tenant for testing.
+
+        Args:
+            db_session_with_containers: Database session from testcontainers infrastructure
+
+        Returns:
+            tuple: (account, tenant) - Created account and tenant instances
+        """
+        fake = Faker()
+
+        # Create account
+        account = Account(
+            email=fake.email(),
+            name=fake.name(),
+            interface_language="en-US",
+            status="active",
+        )
+        db.session.add(account)
+        db.session.commit()
+
+        # Create tenant
+        tenant = Tenant(
+            name=fake.company(),
+            status="normal",
+            plan="basic",
+        )
+        db.session.add(tenant)
+        db.session.commit()
+
+        # Create tenant-account join with owner role
+        join = TenantAccountJoin(
+            tenant_id=tenant.id,
+            account_id=account.id,
+            role=TenantAccountRole.OWNER.value,
+            current=True,
+        )
+        db.session.add(join)
+        db.session.commit()
+
+        # Set current tenant for account
+        account.current_tenant = tenant
+
+        return account, tenant
+
+    def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset:
+        """
+        Helper method to create a test dataset.
+
+        Args:
+            tenant: Tenant instance
+            account: Account instance
+
+        Returns:
+            Dataset: Created dataset instance
+        """
+        fake = Faker()
+
+        dataset = Dataset(
+            tenant_id=tenant.id,
+            name=fake.sentence(nb_words=3),
+            description=fake.text(max_nb_chars=200),
+            data_source_type="upload_file",
+            indexing_technique="high_quality",
+            created_by=account.id,
+        )
+        db.session.add(dataset)
+        db.session.commit()
+
+        return dataset
+
+    def _create_test_document(
+        self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model"
+    ) -> Document:
+        """
+        Helper method to create a test document.
+
+        Args:
+            dataset: Dataset instance
+            tenant: Tenant instance
+            account: Account instance
+            doc_form: Document form type
+
+        Returns:
+            Document: Created document instance
+        """
+        fake = Faker()
+
+        document = Document(
+            tenant_id=tenant.id,
+            dataset_id=dataset.id,
+            position=1,
+            data_source_type="upload_file",
+            batch=fake.uuid4(),
+            name=fake.file_name(),
+            created_from="api",
+            created_by=account.id,
+            indexing_status="completed",
+            enabled=True,
+            archived=False,
+            doc_form=doc_form,
+            word_count=1000,
+            tokens=500,
+            completed_at=datetime.now(UTC),
+        )
+        db.session.add(document)
+        db.session.commit()
+
+        return document
+
+    def _create_test_segment(
+        self,
+        document: Document,
+        dataset: Dataset,
+        tenant: Tenant,
+        account: Account,
+        status: str = "completed",
+        enabled: bool = True,
+    ) -> DocumentSegment:
+        """
+        Helper method to create a test document segment.
+
+        Args:
+            document: Document instance
+            dataset: Dataset instance
+            tenant: Tenant instance
+            account: Account instance
+            status: Segment status
+            enabled: Whether segment is enabled
+
+        Returns:
+            DocumentSegment: Created segment instance
+        """
+        fake = Faker()
+
+        segment = DocumentSegment(
+            tenant_id=tenant.id,
+            dataset_id=dataset.id,
+            document_id=document.id,
+            position=1,
+            content=fake.text(max_nb_chars=500),
+            word_count=100,
+            tokens=50,
+            index_node_id=fake.uuid4(),
+            index_node_hash=fake.sha256(),
+            status=status,
+            enabled=enabled,
+            created_by=account.id,
+            completed_at=datetime.now(UTC) if status == "completed" else None,
+        )
+        db.session.add(segment)
+        db.session.commit()
+
+        return segment
+
+    def test_disable_segment_success(self, db_session_with_containers, mock_index_processor):
+        """
+        Test successful segment disabling from index.
+
+        This test verifies:
+        - Segment is found and validated
+        - Index processor clean method is called with correct parameters
+        - Redis cache is cleared
+        - Task completes successfully
+        """
+        # Arrange: Create test data
+        account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
+        dataset = self._create_test_dataset(tenant, account)
+        document = self._create_test_document(dataset, tenant, account)
+        segment = self._create_test_segment(document, dataset, tenant, account)
+
+        # Set up Redis cache
+        indexing_cache_key = f"segment_{segment.id}_indexing"
+        redis_client.setex(indexing_cache_key, 600, 1)
+
+        # Act: Execute the task
+        result = disable_segment_from_index_task(segment.id)
+
+        # Assert: Verify the task completed successfully
+        assert result is None  # Task returns None on success
+
+        # Verify index processor was called correctly
+        mock_index_processor.clean.assert_called_once()
+        call_args = mock_index_processor.clean.call_args
+        assert call_args[0][0].id == dataset.id  # Check dataset ID
+        assert call_args[0][1] == [segment.index_node_id]  # Check index node IDs
+
+        # Verify Redis cache was cleared
+        assert redis_client.get(indexing_cache_key) is None
+
+        # Verify segment is still in database
+        db.session.refresh(segment)
+        assert segment.id is not None
+
+    def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor):
+        """
+        Test handling when segment is not found.
+ + This test verifies: + - Task handles non-existent segment gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Use a non-existent segment ID + fake = Faker() + non_existent_segment_id = fake.uuid4() + + # Act: Execute the task with non-existent segment + result = disable_segment_from_index_task(non_existent_segment_id) + + # Assert: Verify the task handled the error gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment is not in completed status. + + This test verifies: + - Task rejects segments that are not completed + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with non-completed segment + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the invalid status gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment has no associated dataset. + + This test verifies: + - Task handles segments without dataset gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Manually remove dataset association + segment.dataset_id = "00000000-0000-0000-0000-000000000000" + db.session.commit() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the missing dataset gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment has no associated document. 
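+
+        (The document lookup is assumed to resolve through segment.document_id,
+        so pointing that foreign key at an unused UUID below makes the lookup
+        return None and the task return early.)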
+ + This test verifies: + - Task handles segments without document gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Manually remove document association + segment.document_id = "00000000-0000-0000-0000-000000000000" + db.session.commit() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the missing document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document is disabled. + + This test verifies: + - Task handles disabled documents gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with disabled document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.enabled = False + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the disabled document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document is archived. + + This test verifies: + - Task handles archived documents gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with archived document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.archived = True + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the archived document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document indexing is not completed. 
+ + This test verifies: + - Task handles documents with incomplete indexing gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with incomplete indexing + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.indexing_status = "indexing" + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the incomplete indexing gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + """ + Test handling when index processor raises an exception. + + This test verifies: + - Task handles index processor exceptions gracefully + - Segment is re-enabled on failure + - Redis cache is still cleared + - Database changes are committed + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Set up Redis cache + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + + # Configure mock to raise exception + mock_index_processor.clean.side_effect = Exception("Index processor error") + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the exception gracefully + assert result is None + + # Verify index processor was called + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + # Check that the call was made with the correct parameters + assert len(call_args[0]) == 2 # Check two arguments were passed + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + # Verify segment was re-enabled + db.session.refresh(segment) + assert segment.enabled is True + + # Verify Redis cache was still cleared + assert redis_client.get(indexing_cache_key) is None + + def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + """ + Test disabling segments with different document forms. 
+ + This test verifies: + - Task works with different document form types + - Correct index processor is initialized for each form + - Index processor clean method is called correctly + """ + # Test different document forms + doc_forms = ["text_model", "qa_model", "table_model"] + + for doc_form in doc_forms: + # Arrange: Create test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Reset mock for each iteration + mock_index_processor.reset_mock() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task completed successfully + assert result is None + + # Verify correct index processor was initialized + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Check dataset ID + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + """ + Test Redis cache handling during segment disabling. + + This test verifies: + - Redis cache is properly set before task execution + - Cache is cleared after task completion + - Cache handling works with different scenarios + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Test with cache present + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + assert redis_client.get(indexing_cache_key) is not None + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify cache was cleared + assert result is None + assert redis_client.get(indexing_cache_key) is None + + # Test with no cache present + segment2 = self._create_test_segment(document, dataset, tenant, account) + result2 = disable_segment_from_index_task(segment2.id) + + # Assert: Verify task still works without cache + assert result2 is None + + def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + """ + Test performance timing of segment disabling task. 
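+
+        The timing pattern assumed inside the task is the usual perf_counter
+        bracket (a sketch with hypothetical names, not the task's verbatim code):
+
+            start_at = time.perf_counter()
+            # ... look up the segment, clean the index, clear the cache ...
+            logger.info("Segment removed from index: %s latency: %s", segment_id, time.perf_counter() - start_at)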
+ + This test verifies: + - Task execution time is reasonable + - Performance logging works correctly + - Task completes within expected time bounds + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task and measure time + start_time = time.perf_counter() + result = disable_segment_from_index_task(segment.id) + end_time = time.perf_counter() + + # Assert: Verify task completed successfully and timing is reasonable + assert result is None + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + """ + Test database session management during task execution. + + This test verifies: + - Database sessions are properly managed + - Sessions are closed after task completion + - No session leaks occur + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify task completed and session management worked + assert result is None + + # Verify segment is still accessible (session was properly managed) + db.session.refresh(segment) + assert segment.id is not None + + def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + """ + Test concurrent execution of segment disabling tasks. 
+ + This test verifies: + - Multiple tasks can run concurrently + - Each task processes its own segment correctly + - No interference between concurrent tasks + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + + segments = [] + for i in range(3): + segment = self._create_test_segment(document, dataset, tenant, account) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + results = [] + for segment in segments: + result = disable_segment_from_index_task(segment.id) + results.append(result) + + # Assert: Verify all tasks completed successfully + assert all(result is None for result in results) + + # Verify all segments were processed + assert mock_index_processor.clean.call_count == len(segments) + + # Verify each segment was processed with correct parameters + for segment in segments: + # Check that clean was called with this segment's dataset and index_node_id + found = False + for call in mock_index_processor.clean.call_args_list: + if call[0][0].id == dataset.id and call[0][1] == [segment.index_node_id]: + found = True + break + assert found, f"Segment {segment.id} was not processed correctly" diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py new file mode 100644 index 0000000000..5fdb8c617c --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -0,0 +1,729 @@ +""" +TestContainers-based integration tests for disable_segments_from_index_task. + +This module provides comprehensive integration testing for the disable_segments_from_index_task +using TestContainers to ensure realistic database interactions and proper isolation. +The task is responsible for removing document segments from the search index when they are disabled. +""" + +from unittest.mock import MagicMock, patch + +from faker import Faker + +from models import Account, Dataset, DocumentSegment +from models import Document as DatasetDocument +from models.dataset import DatasetProcessRule +from tasks.disable_segments_from_index_task import disable_segments_from_index_task + + +class TestDisableSegmentsFromIndexTask: + """ + Comprehensive integration tests for disable_segments_from_index_task using testcontainers. + + This test class covers all major functionality of the disable_segments_from_index_task: + - Successful segment disabling with proper index cleanup + - Error handling for various edge cases + - Database state validation after task execution + - Redis cache cleanup verification + - Index processor integration testing + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + def _create_test_account(self, db_session_with_containers, fake=None): + """ + Helper method to create a test account with realistic data. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Account: Created test account instance + """ + fake = fake or Faker() + account = Account() + account.id = fake.uuid4() + account.email = fake.email() + account.name = fake.name() + account.avatar_url = fake.url() + account.tenant_id = fake.uuid4() + account.status = "active" + account.type = "normal" + account.role = "owner" + account.interface_language = "en-US" + account.created_at = fake.date_time_this_year() + account.updated_at = account.created_at + + # Create a tenant for the account + from models.account import Tenant + + tenant = Tenant() + tenant.id = account.tenant_id + tenant.name = f"Test Tenant {fake.company()}" + tenant.plan = "basic" + tenant.status = "active" + tenant.created_at = fake.date_time_this_year() + tenant.updated_at = tenant.created_at + + from extensions.ext_database import db + + db.session.add(tenant) + db.session.add(account) + db.session.commit() + + # Set the current tenant for the account + account.current_tenant = tenant + + return account + + def _create_test_dataset(self, db_session_with_containers, account, fake=None): + """ + Helper method to create a test dataset with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + account: The account creating the dataset + fake: Faker instance for generating test data + + Returns: + Dataset: Created test dataset instance + """ + fake = fake or Faker() + dataset = Dataset() + dataset.id = fake.uuid4() + dataset.tenant_id = account.tenant_id + dataset.name = f"Test Dataset {fake.word()}" + dataset.description = fake.text(max_nb_chars=200) + dataset.provider = "vendor" + dataset.permission = "only_me" + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + dataset.created_by = account.id + dataset.updated_by = account.id + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.built_in_field_enabled = False + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + """ + Helper method to create a test document with realistic data. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: The dataset containing the document + account: The account creating the document + fake: Faker instance for generating test data + + Returns: + DatasetDocument: Created test document instance + """ + fake = fake or Faker() + document = DatasetDocument() + document.id = fake.uuid4() + document.tenant_id = dataset.tenant_id + document.dataset_id = dataset.id + document.position = 1 + document.data_source_type = "upload_file" + document.data_source_info = '{"upload_file_id": "test_file_id"}' + document.batch = fake.uuid4() + document.name = f"Test Document {fake.word()}.txt" + document.created_from = "upload_file" + document.created_by = account.id + document.created_api_request_id = fake.uuid4() + document.processing_started_at = fake.date_time_this_year() + document.file_id = fake.uuid4() + document.word_count = fake.random_int(min=100, max=1000) + document.parsing_completed_at = fake.date_time_this_year() + document.cleaning_completed_at = fake.date_time_this_year() + document.splitting_completed_at = fake.date_time_this_year() + document.tokens = fake.random_int(min=50, max=500) + document.indexing_started_at = fake.date_time_this_year() + document.indexing_completed_at = fake.date_time_this_year() + document.indexing_status = "completed" + document.enabled = True + document.archived = False + document.doc_form = "text_model" # Use text_model form for testing + document.doc_language = "en" + + from extensions.ext_database import db + + db.session.add(document) + db.session.commit() + + return document + + def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + """ + Helper method to create test document segments with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + document: The document containing the segments + dataset: The dataset containing the document + account: The account creating the segments + count: Number of segments to create + fake: Faker instance for generating test data + + Returns: + List[DocumentSegment]: Created test segment instances + """ + fake = fake or Faker() + segments = [] + + for i in range(count): + segment = DocumentSegment() + segment.id = fake.uuid4() + segment.tenant_id = dataset.tenant_id + segment.dataset_id = dataset.id + segment.document_id = document.id + segment.position = i + 1 + segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}" + segment.answer = f"Test answer {i + 1}" if i % 2 == 0 else None + segment.word_count = fake.random_int(min=10, max=100) + segment.tokens = fake.random_int(min=5, max=50) + segment.keywords = [fake.word() for _ in range(3)] + segment.index_node_id = f"node_{segment.id}" + segment.index_node_hash = fake.sha256() + segment.hit_count = 0 + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + segment.status = "completed" + segment.created_by = account.id + segment.updated_by = account.id + segment.indexing_at = fake.date_time_this_year() + segment.completed_at = fake.date_time_this_year() + segment.error = None + segment.stopped_at = None + + segments.append(segment) + + from extensions.ext_database import db + + for segment in segments: + db.session.add(segment) + db.session.commit() + + return segments + + def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + """ + Helper method to create a dataset process rule. 
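+
+        The `rules` column holds a JSON string; decoded, the payload assembled
+        below is:
+
+            {"mode": "automatic",
+             "rules": {"pre_processing_rules": [],
+                       "segmentation": {"separator": "\n\n",
+                                        "max_tokens": 1000,
+                                        "chunk_overlap": 50}}}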
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: The dataset for the process rule + fake: Faker instance for generating test data + + Returns: + DatasetProcessRule: Created process rule instance + """ + fake = fake or Faker() + process_rule = DatasetProcessRule() + process_rule.id = fake.uuid4() + process_rule.tenant_id = dataset.tenant_id + process_rule.dataset_id = dataset.id + process_rule.mode = "automatic" + process_rule.rules = ( + "{" + '"mode": "automatic", ' + '"rules": {' + '"pre_processing_rules": [], "segmentation": ' + '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' + "}" + ) + process_rule.created_by = dataset.created_by + process_rule.updated_by = dataset.updated_by + + from extensions.ext_database import db + + db.session.add(process_rule) + db.session.commit() + + return process_rule + + def test_disable_segments_success(self, db_session_with_containers): + """ + Test successful disabling of segments from index. + + This test verifies that the task can correctly disable segments from the index + when all conditions are met, including proper index cleanup and database state updates. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor to avoid external dependencies + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify index processor was called correctly + mock_factory.assert_called_once_with(document.doc_form) + mock_processor.clean.assert_called_once() + + # Verify the call arguments (checking by attributes rather than object identity) + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # First argument should be the dataset + assert sorted(call_args[0][1]) == sorted( + [segment.index_node_id for segment in segments] + ) # Compare sorted lists to handle any order while preserving duplicates + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is False + + # Verify Redis cache cleanup was called for each segment + assert mock_redis.delete.call_count == len(segments) + for segment in segments: + expected_key = f"segment_{segment.id}_indexing" + mock_redis.delete.assert_any_call(expected_key) + + def test_disable_segments_dataset_not_found(self, db_session_with_containers): + """ + Test handling when dataset is not found. + + This test ensures that the task correctly handles cases where the specified + dataset doesn't exist, logging appropriate messages and returning early. 
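+
+        The guard being exercised is assumed to look roughly like this (a
+        sketch, not the task's verbatim code):
+
+            dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+            if not dataset:
+                logger.info("Dataset %s not found, skipping segment disable.", dataset_id)
+                return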
+ """ + # Arrange + fake = Faker() + non_existent_dataset_id = fake.uuid4() + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when dataset is not found + mock_redis.delete.assert_not_called() + + def test_disable_segments_document_not_found(self, db_session_with_containers): + """ + Test handling when document is not found. + + This test ensures that the task correctly handles cases where the specified + document doesn't exist, logging appropriate messages and returning early. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, non_existent_document_id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when document is not found + mock_redis.delete.assert_not_called() + + def test_disable_segments_document_invalid_status(self, db_session_with_containers): + """ + Test handling when document has invalid status for disabling. + + This test ensures that the task correctly handles cases where the document + is not enabled, archived, or not completed, preventing invalid operations. 
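+
+        The status check assumed on the task side (hedged sketch):
+
+            if not document.enabled or document.archived or document.indexing_status != "completed":
+                return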
+ """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + + # Test case 1: Document not enabled + document.enabled = False + from extensions.ext_database import db + + db.session.commit() + + segment_ids = [segment.id for segment in segments] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when document status is invalid + mock_redis.delete.assert_not_called() + + # Test case 2: Document archived + document.enabled = True + document.archived = True + db.session.commit() + + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_redis.delete.assert_not_called() + + # Test case 3: Document indexing not completed + document.enabled = True + document.archived = False + document.indexing_status = "indexing" + db.session.commit() + + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_redis.delete.assert_not_called() + + def test_disable_segments_no_segments_found(self, db_session_with_containers): + """ + Test handling when no segments are found for the given IDs. + + This test ensures that the task correctly handles cases where the specified + segment IDs don't exist or don't match the dataset/document criteria. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + # Use non-existent segment IDs + non_existent_segment_ids = [fake.uuid4() for _ in range(3)] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(non_existent_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when no segments are found + mock_redis.delete.assert_not_called() + + def test_disable_segments_index_processor_error(self, db_session_with_containers): + """ + Test handling when index processor encounters an error. + + This test verifies that the task correctly handles index processor errors + by rolling back segment states and ensuring proper cleanup. 
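+
+        The rollback shape this test assumes (a sketch; the attribute names and
+        keyword arguments are taken from the assertions below, not from the
+        task source):
+
+            try:
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+            except Exception:
+                for segment in segments:
+                    segment.enabled = True
+                    segment.disabled_at = None
+                    segment.disabled_by = None
+                db.session.commit()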
+ """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor to raise an exception + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean.side_effect = Exception("Index processor error") + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify segments were rolled back to enabled state + from extensions.ext_database import db + + db.session.refresh(segments[0]) + db.session.refresh(segments[1]) + + # Check that segments are re-enabled after error + updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + + for segment in updated_segments: + assert segment.enabled is True + assert segment.disabled_at is None + assert segment.disabled_by is None + + # Verify Redis cache cleanup was still called + assert mock_redis.delete.call_count == len(segments) + + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + """ + Test disabling segments with different document forms. + + This test verifies that the task correctly handles different document forms + (paragraph, qa, parent_child) and initializes the appropriate index processor. 
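+
+        Note: the test body drives this through the doc_form constants
+        "text_model", "qa_model", and "hierarchical_model", which are assumed
+        to map to the paragraph, QA, and parent-child index processors
+        respectively.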
+ """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Test different document forms + doc_forms = ["text_model", "qa_model", "hierarchical_model"] + + for doc_form in doc_forms: + # Update document form + document.doc_form = doc_form + from extensions.ext_database import db + + db.session.commit() + + # Mock the index processor factory + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_factory.assert_called_with(doc_form) + + def test_disable_segments_performance_timing(self, db_session_with_containers): + """ + Test that the task properly measures and logs performance timing. + + This test verifies that the task correctly measures execution time + and logs performance metrics for monitoring purposes. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Mock time.perf_counter to control timing + with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter: + mock_perf_counter.side_effect = [1000.0, 1000.5] # 0.5 seconds execution time + + # Mock logger to capture log messages + with patch("tasks.disable_segments_from_index_task.logger") as mock_logger: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify performance logging + mock_logger.info.assert_called() + log_calls = [call[0][0] for call in mock_logger.info.call_args_list] + performance_log = next((call for call in log_calls if "latency" in call), None) + assert performance_log is not None + assert "0.5" in performance_log # Should log the execution time + + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + """ + Test that Redis cache is properly cleaned up for 
all segments. + + This test verifies that the task correctly removes indexing cache entries + from Redis for all processed segments, preventing stale cache issues. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 5, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client to track delete calls + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify Redis delete was called for each segment + assert mock_redis.delete.call_count == len(segments) + + # Verify correct cache keys were used + expected_keys = [f"segment_{segment.id}_indexing" for segment in segments] + actual_calls = [call[0][0] for call in mock_redis.delete.call_args_list] + + for expected_key in expected_keys: + assert expected_key in actual_calls + + def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + """ + Test that database session is properly closed after task execution. + + This test verifies that the task correctly manages database sessions + and ensures proper cleanup to prevent connection leaks. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Mock db.session.close to verify it's called + with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Verify session was closed + mock_close.assert_called() + + def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + """ + Test handling when empty segment IDs list is provided. + + This test ensures that the task correctly handles edge cases where + an empty list of segment IDs is provided. 
+ """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + empty_segment_ids = [] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(empty_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when no segments are provided + mock_redis.delete.assert_not_called() + + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + """ + Test handling when some segment IDs are valid and others are invalid. + + This test verifies that the task correctly processes only the valid + segment IDs and ignores invalid ones. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + # Mix valid and invalid segment IDs + valid_segment_ids = [segment.id for segment in segments] + invalid_segment_ids = [fake.uuid4() for _ in range(2)] + mixed_segment_ids = valid_segment_ids + invalid_segment_ids + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(mixed_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify index processor was called with only valid segment node IDs + expected_node_ids = [segment.index_node_id for segment in segments] + mock_processor.clean.assert_called_once() + + # Verify the call arguments + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # First argument should be the dataset + assert sorted(call_args[0][1]) == sorted( + expected_node_ids + ) # Compare sorted lists to handle any order while preserving duplicates + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is False + + # Verify Redis cleanup was called only for valid segments + assert mock_redis.delete.call_count == len(segments) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py new file mode 100644 index 0000000000..f75dcf06e1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -0,0 +1,554 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin, 
TenantAccountRole +from models.dataset import Dataset, Document +from tasks.document_indexing_task import document_indexing_task + + +class TestDocumentIndexingTask: + """Integration tests for document_indexing_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner, + patch("tasks.document_indexing_task.FeatureService") as mock_feature_service, + ): + # Setup mock indexing runner + mock_runner_instance = MagicMock() + mock_indexing_runner.return_value = mock_runner_instance + + # Setup mock feature service + mock_features = MagicMock() + mock_features.billing.enabled = False + mock_feature_service.get_features.return_value = mock_features + + yield { + "indexing_runner": mock_indexing_runner, + "indexing_runner_instance": mock_runner_instance, + "feature_service": mock_feature_service, + "features": mock_features, + } + + def _create_test_dataset_and_documents( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + ): + """ + Helper method to create a test dataset and documents for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(document_count): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def _create_test_dataset_with_billing_features( + self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ): + """ + Helper method to create a test dataset with billing features configured. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + billing_enabled: Whether billing is enabled + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(3): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Configure billing features + mock_external_service_dependencies["features"].billing.enabled = billing_enabled + if billing_enabled: + mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox" + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 50 + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful document indexing with multiple documents. 
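+
+        The transition asserted below is assumed to happen in the task before
+        the runner is invoked (hedged sketch; the timestamp expression is a
+        placeholder):
+
+            for document in documents:
+                document.indexing_status = "parsing"
+                document.processing_started_at = datetime.now(UTC)  # placeholder timestamp helper
+            db.session.commit()
+            IndexingRunner().run(documents)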
+ + This test verifies: + - Proper dataset retrieval from database + - Correct document processing and status updates + - IndexingRunner integration + - Database state updates + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify the expected outcomes + # Verify indexing runner was called correctly + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 3 + + def test_document_indexing_task_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. + + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary indexing runner calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + document_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent dataset + document_indexing_task(non_existent_dataset_id, document_ids) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["indexing_runner"].assert_not_called() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_document_indexing_task_document_not_found_in_dataset( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when some documents don't exist in the dataset. 
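+
+        (The task is assumed to resolve IDs with a query filtered on both the
+        ID list and the dataset, so IDs that do not belong to the dataset are
+        silently dropped rather than raising.)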
+ + This test verifies: + - Only existing documents are processed + - Non-existent documents are ignored + - Indexing runner receives only valid documents + - Database state updates correctly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Mix existing and non-existent document IDs + fake = Faker() + existing_document_ids = [doc.id for doc in documents] + non_existent_document_ids = [fake.uuid4() for _ in range(2)] + all_document_ids = existing_document_ids + non_existent_document_ids + + # Act: Execute the task with mixed document IDs + document_indexing_task(dataset.id, all_document_ids) + + # Assert: Verify only existing documents were processed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only existing documents were updated + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + # Verify the run method was called with only existing documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 2 # Only existing documents + + def test_document_indexing_task_indexing_runner_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of IndexingRunner exceptions. + + This test verifies: + - Exceptions from IndexingRunner are properly caught + - Task completes without raising exceptions + - Database session is properly closed + - Error logging occurs + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception( + "Indexing runner failed" + ) + + # Act: Execute the task + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + def test_document_indexing_task_mixed_document_states( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test processing documents with mixed initial states. 
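+
+        (Per the assertions below, the task does not appear to filter on
+        enabled or indexing_status when resolving explicit IDs, so even a
+        completed or disabled document is moved to "parsing" once listed.)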
+ + This test verifies: + - Documents with different initial states are handled correctly + - Only valid documents are processed + - Database state updates are consistent + - IndexingRunner receives correct documents + """ + # Arrange: Create test data + dataset, base_documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Create additional documents with different states + fake = Faker() + extra_documents = [] + + # Document with different indexing status + doc1 = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=2, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="completed", # Already completed + enabled=True, + ) + db.session.add(doc1) + extra_documents.append(doc1) + + # Document with disabled status + doc2 = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=False, # Disabled + ) + db.session.add(doc2) + extra_documents.append(doc2) + + db.session.commit() + + all_documents = base_documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with mixed document states + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify processing + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify all documents were updated to parsing status + for document in all_documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + # Verify the run method was called with all documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 4 + + def test_document_indexing_task_billing_sandbox_plan_batch_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for sandbox plan batch upload limit. 
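+
+        The billing guard assumed here (hedged sketch; the one-document batch
+        limit for the sandbox plan is inferred from this test, and the error
+        message is hypothetical apart from the "batch upload" substring the
+        assertions check):
+
+            features = FeatureService.get_features(tenant_id)
+            if features.billing.enabled and features.billing.subscription.plan == "sandbox" and len(document_ids) > 1:
+                raise ValueError("batch upload is not allowed on the sandbox plan")
+
+        with the raised error then recorded on each document (indexing_status
+        "error", an error message, and a stopped_at timestamp), as asserted
+        below.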
+ + This test verifies: + - Sandbox plan batch upload limit enforcement + - Error handling for batch upload limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure sandbox plan with batch limit + mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox" + + # Create more documents than sandbox plan allows (limit is 1) + fake = Faker() + extra_documents = [] + for i in range(2): # Total will be 5 documents (3 existing + 2 new) + document = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=i + 3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=True, + ) + db.session.add(document) + extra_documents.append(document) + + db.session.commit() + all_documents = documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with too many documents for sandbox plan + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + for document in all_documents: + db.session.refresh(document) + assert document.indexing_status == "error" + assert document.error is not None + assert "batch upload" in document.error + assert document.stopped_at is not None + + # Verify no indexing runner was called + mock_external_service_dependencies["indexing_runner"].assert_not_called() + + def test_document_indexing_task_billing_disabled_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful processing when billing is disabled. + + This test verifies: + - Processing continues normally when billing is disabled + - No billing validation occurs + - Documents are processed successfully + - IndexingRunner is called correctly + """ + # Arrange: Create test data with billing disabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=False + ) + + document_ids = [doc.id for doc in documents] + + # Act: Execute the task with billing disabled + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify successful processing + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + def test_document_indexing_task_document_is_paused_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of DocumentIsPausedError from IndexingRunner. 
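+
+        The handler shape assumed in the task (hedged sketch):
+
+            try:
+                indexing_runner.run(documents)
+            except DocumentIsPausedError:
+                logger.info("Document paused, stopping indexing.")
+            except Exception:
+                logger.exception("Document indexing failed")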
+ + This test verifies: + - DocumentIsPausedError is properly caught and handled + - Task completes without raising exceptions + - Appropriate logging occurs + - Database session is properly closed + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise DocumentIsPausedError + from core.indexing_runner import DocumentIsPausedError + + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError( + "Document indexing is paused" + ) + + # Act: Execute the task + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index aefb4bf8b0..b6697ac5d4 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -9,7 +9,6 @@ from flask_restx import Api import services.errors.account from controllers.console.auth.error import AuthenticationFailedError from controllers.console.auth.login import LoginApi -from controllers.console.error import AccountNotFound class TestAuthenticationSecurity: @@ -27,31 +26,33 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.authenticate") - @patch("controllers.console.auth.login.AccountService.send_reset_password_email") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") def test_login_invalid_email_with_registration_allowed( - self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): - """Test that invalid email sends reset password email when registration is allowed.""" + """Test that invalid email raises AuthenticationFailedError when account not found.""" # Arrange mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None - mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = True - mock_send_email.return_value = "token123" # Act with self.app.test_request_context( "/login", 
method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} ): login_api = LoginApi() - result = login_api.post() - # Assert - assert result == {"result": "fail", "data": "token123", "code": "account_not_found"} - mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US") + # Assert + with pytest.raises(AuthenticationFailedError) as exc_info: + login_api.post() + + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." + mock_add_rate_limit.assert_called_once_with("nonexistent@example.com") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @@ -87,16 +88,17 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") def test_login_invalid_email_with_registration_disabled( - self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): - """Test that invalid email raises AccountNotFound when registration is disabled.""" + """Test that invalid email raises AuthenticationFailedError when account not found.""" # Arrange mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None - mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = False @@ -107,10 +109,12 @@ class TestAuthenticationSecurity: login_api = LoginApi() # Assert - with pytest.raises(AccountNotFound) as exc_info: + with pytest.raises(AuthenticationFailedError) as exc_info: login_api.post() - assert exc_info.value.error_code == "account_not_found" + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." 
+ mock_add_rate_limit.assert_called_once_with("nonexistent@example.com") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.FeatureService.get_system_features") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 037c9f2745..a7bdf5de33 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -12,7 +12,7 @@ from controllers.console.auth.oauth import ( ) from libs.oauth import OAuthUserInfo from models.account import AccountStatus -from services.errors.account import AccountNotFoundError +from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @@ -451,7 +451,7 @@ class TestAccountGeneration: with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): if not allow_register and not existing_account: - with pytest.raises(AccountNotFoundError): + with pytest.raises(AccountRegisterError): _generate_account("github", user_info) else: result = _generate_account("github", user_info) diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f1d741602a..895ebdd751 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -29,7 +29,7 @@ class TestHandleMCPRequest: """Setup test fixtures""" self.app = Mock(spec=App) self.app.name = "test_app" - self.app.mode = AppMode.CHAT.value + self.app.mode = AppMode.CHAT self.mcp_server = Mock(spec=AppMCPServer) self.mcp_server.description = "Test server" @@ -196,7 +196,7 @@ class TestIndividualHandlers: def test_handle_list_tools(self): """Test list tools handler""" app_name = "test_app" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT description = "Test server" parameters_dict: dict[str, str] = {} user_input_form: list[VariableEntity] = [] @@ -212,7 +212,7 @@ class TestIndividualHandlers: def test_handle_call_tool(self, mock_app_generate): """Test call tool handler""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create mock request mock_request = Mock() @@ -252,7 +252,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_chat_mode(self): """Test building parameter schema for chat mode""" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT parameters_dict: dict[str, str] = {"name": "Enter your name"} user_input_form = [ @@ -275,7 +275,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_workflow_mode(self): """Test building parameter schema for workflow mode""" - app_mode = AppMode.WORKFLOW.value + app_mode = AppMode.WORKFLOW parameters_dict: dict[str, str] = {"input_text": "Enter text"} user_input_form = [ @@ -298,7 +298,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_chat_mode(self): """Test preparing tool arguments for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT arguments = {"query": "test question", "name": "John"} @@ -312,7 +312,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_workflow_mode(self): """Test preparing tool arguments for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW arguments = {"input_text": "test input"} @@ -324,7 +324,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_completion_mode(self): """Test preparing 
tool arguments for completion mode""" app = Mock(spec=App) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION arguments = {"name": "John"} @@ -336,7 +336,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_chat(self): """Test extracting answer from mapping response for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT response = {"answer": "test answer", "other": "data"} @@ -347,7 +347,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_workflow(self): """Test extracting answer from mapping response for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW response = {"data": {"outputs": {"result": "test result"}}} diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 607728efd8..6689e13b96 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -20,7 +20,6 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): } mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) job_id = firecrawl_app.crawl_url(url, params) - print(f"job_id: {job_id}") assert job_id is not None assert isinstance(job_id, str) diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py index 57ddacd13d..0bf4a3cf91 100644 --- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -15,7 +15,7 @@ class FakeResponse: self.status_code = status_code self.headers = headers or {} self.content = content - self.text = text if text else content.decode("utf-8", errors="ignore") + self.text = text or content.decode("utf-8", errors="ignore") # --------------------------- diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 4712960e31..5cd595088a 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -129,7 +129,6 @@ class TestSegmentDumpAndLoad: """Test basic segment serialization compatibility""" model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")]) json = model.model_dump_json() - print("Json: ", json) loaded = _Segments.model_validate_json(json) assert loaded == model @@ -137,7 +136,6 @@ class TestSegmentDumpAndLoad: """Test number segment serialization compatibility""" model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)]) json = model.model_dump_json() - print("Json: ", json) loaded = _Segments.model_validate_json(json) assert loaded == model @@ -145,7 +143,6 @@ class TestSegmentDumpAndLoad: """Test variable serialization compatibility""" model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")]) json = model.model_dump_json() - print("Json: ", json) restored = _Variables.model_validate_json(json) assert restored == model diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 95880a852c..61ce640edd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,7 +1,6 @@ import base64 import 
uuid from collections.abc import Sequence -from typing import Optional from unittest import mock import pytest @@ -44,7 +43,7 @@ class MockTokenBufferMemory: self.history_messages = history_messages or [] def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: int | None = None ) -> Sequence[PromptMessage]: if message_limit is not None: return self.history_messages[-message_limit * 2 :] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 41bbf60d90..b842dfdb58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -142,15 +142,11 @@ def test_remove_first_from_array(): node.init_node_data(node_config["data"]) # Skip the mock assertion since we're in a test environment - # Print the variable before running - print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") # Run the node result = list(node.run()) - # Print the variable after running and the result - print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") - print(f"Result: {result}") + # Completed run got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py new file mode 100644 index 0000000000..324f58abf6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -0,0 +1,456 @@ +import pytest + +from core.file.enums import FileType +from core.file.models import File, FileTransferMethod +from core.variables.variables import StringVariable +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from core.workflow.workflow_entry import WorkflowEntry + + +class TestWorkflowEntry: + """Test WorkflowEntry class methods.""" + + def test_mapping_user_inputs_to_variable_pool_with_system_variables(self): + """Test mapping system variables from user inputs to variable pool.""" + # Initialize variable pool with system variables + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + ), + user_inputs={}, + ) + + # Define variable mapping - sys variables mapped to other nodes + variable_mapping = { + "node1.input1": ["node1", "input1"], # Regular mapping + "node2.query": ["node2", "query"], # Regular mapping + "sys.user_id": ["output_node", "user"], # System variable mapping + } + + # User inputs including sys variables + user_inputs = { + "node1.input1": "new_user_id", + "node2.query": "test query", + "sys.user_id": "system_user", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added to pool + # Note: variable_pool.get returns Variable objects, not raw values + node1_var = variable_pool.get(["node1", "input1"]) + assert node1_var is not None + 
assert node1_var.value == "new_user_id" + + node2_var = variable_pool.get(["node2", "query"]) + assert node2_var is not None + assert node2_var.value == "test query" + + # System variable gets mapped to output node + output_var = variable_pool.get(["output_node", "user"]) + assert output_var is not None + assert output_var.value == "system_user" + + def test_mapping_user_inputs_to_variable_pool_with_env_variables(self): + """Test mapping environment variables from user inputs to variable pool.""" + # Initialize variable pool with environment variables + env_var = StringVariable(name="API_KEY", value="existing_key") + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + environment_variables=[env_var], + user_inputs={}, + ) + + # Add env variable to pool (simulating initialization) + variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var) + + # Define variable mapping - env variables should not be overridden + variable_mapping = { + "node1.api_key": [ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], + "node2.new_env": [ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"], + } + + # User inputs + user_inputs = { + "node1.api_key": "user_provided_key", # This should not override existing env var + "node2.new_env": "new_env_value", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify env variable was not overridden + env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"]) + assert env_value is not None + assert env_value.value == "existing_key" # Should remain unchanged + + # New env variables from user input should not be added + assert variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"]) is None + + def test_mapping_user_inputs_to_variable_pool_with_conversation_variables(self): + """Test mapping conversation variables from user inputs to variable pool.""" + # Initialize variable pool with conversation variables + conv_var = StringVariable(name="last_message", value="Hello") + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + conversation_variables=[conv_var], + user_inputs={}, + ) + + # Add conversation variable to pool + variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "last_message"], conv_var) + + # Define variable mapping + variable_mapping = { + "node1.message": ["node1", "message"], # Map to regular node + "conversation.context": ["chat_node", "context"], # Conversation var to regular node + } + + # User inputs + user_inputs = { + "node1.message": "Updated message", + "conversation.context": "New context", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added to their target nodes + node1_var = variable_pool.get(["node1", "message"]) + assert node1_var is not None + assert node1_var.value == "Updated message" + + chat_var = variable_pool.get(["chat_node", "context"]) + assert chat_var is not None + assert chat_var.value == "New context" + + def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self): + """Test mapping regular node variables from user inputs to variable pool.""" + # Initialize empty variable pool + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping for regular nodes + 
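        # (Note: each mapping key is "<node_id>.<variable_name>", and each selector
+        # is a two-element [node_id, variable_name] path into the variable pool.)
+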
variable_mapping = {
+            "input_node.text": ["input_node", "text"],
+            "llm_node.prompt": ["llm_node", "prompt"],
+            "code_node.input": ["code_node", "input"],
+        }
+
+        # User inputs
+        user_inputs = {
+            "input_node.text": "User input text",
+            "llm_node.prompt": "Generate a summary",
+            "code_node.input": {"key": "value"},
+        }
+
+        # Execute mapping
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id="test_tenant",
+        )
+
+        # Verify regular variables were added
+        text_var = variable_pool.get(["input_node", "text"])
+        assert text_var is not None
+        assert text_var.value == "User input text"
+
+        prompt_var = variable_pool.get(["llm_node", "prompt"])
+        assert prompt_var is not None
+        assert prompt_var.value == "Generate a summary"
+
+        input_var = variable_pool.get(["code_node", "input"])
+        assert input_var is not None
+        assert input_var.value == {"key": "value"}
+
+    def test_mapping_user_inputs_with_file_handling(self):
+        """Test mapping file inputs from user inputs to variable pool."""
+        variable_pool = VariablePool(
+            system_variables=SystemVariable.empty(),
+            user_inputs={},
+        )
+
+        # Define variable mapping
+        variable_mapping = {
+            "file_node.file": ["file_node", "file"],
+            "file_node.files": ["file_node", "files"],
+        }
+
+        # User inputs with file data - using remote_url which doesn't require upload_file_id
+        user_inputs = {
+            "file_node.file": {
+                "type": "document",
+                "transfer_method": "remote_url",
+                "url": "http://example.com/test.pdf",
+            },
+            "file_node.files": [
+                {
+                    "type": "image",
+                    "transfer_method": "remote_url",
+                    "url": "http://example.com/image1.jpg",
+                },
+                {
+                    "type": "image",
+                    "transfer_method": "remote_url",
+                    "url": "http://example.com/image2.jpg",
+                },
+            ],
+        }
+
+        # Execute mapping
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id="test_tenant",
+        )
+
+        # Verify file was converted and added
+        file_var = variable_pool.get(["file_node", "file"])
+        assert file_var is not None
+        assert file_var.value.type == FileType.DOCUMENT
+        assert file_var.value.transfer_method == FileTransferMethod.REMOTE_URL
+
+        # Verify file list was converted and added
+        files_var = variable_pool.get(["file_node", "files"])
+        assert files_var is not None
+        assert isinstance(files_var.value, list)
+        assert len(files_var.value) == 2
+        assert all(isinstance(f, File) for f in files_var.value)
+        assert files_var.value[0].type == FileType.IMAGE
+        assert files_var.value[1].type == FileType.IMAGE
+
+    def test_mapping_user_inputs_missing_variable_error(self):
+        """Test that mapping raises error when required variable is missing."""
+        variable_pool = VariablePool(
+            system_variables=SystemVariable.empty(),
+            user_inputs={},
+        )
+
+        # Define variable mapping
+        variable_mapping = {
+            "node1.required_input": ["node1", "required_input"],
+        }
+
+        # User inputs without required variable
+        user_inputs = {
+            "node1.other_input": "some value",
+        }
+
+        # Should raise ValueError for missing variable
+        with pytest.raises(ValueError, match="Variable key node1.required_input not found in user inputs"):
+            WorkflowEntry.mapping_user_inputs_to_variable_pool(
+                variable_mapping=variable_mapping,
+                user_inputs=user_inputs,
+                variable_pool=variable_pool,
+                tenant_id="test_tenant",
+            )
+
+    def test_mapping_user_inputs_with_alternative_key_format(self):
+        """Test mapping with alternative key format (without node prefix)."""
+        variable_pool = VariablePool(
+            system_variables=SystemVariable.empty(),
+            user_inputs={},
+        )
+
+        # Define variable mapping
+        variable_mapping = {
+            "node1.input": ["node1", "input"],
+        }
+
+        # User inputs with alternative key format
+        user_inputs = {
+            "input": "value without node prefix",  # Alternative format without node prefix
+        }
+
+        # Execute mapping
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id="test_tenant",
+        )
+
+        # Verify variable was added using alternative key
+        input_var = variable_pool.get(["node1", "input"])
+        assert input_var is not None
+        assert input_var.value == "value without node prefix"
+
+    def test_mapping_user_inputs_with_complex_selectors(self):
+        """Test mapping with complex node variable keys."""
+        variable_pool = VariablePool(
+            system_variables=SystemVariable.empty(),
+            user_inputs={},
+        )
+
+        # Define variable mapping - selectors can only have 2 elements
+        variable_mapping = {
+            "node1.data.field1": ["node1", "data_field1"],  # Complex key mapped to simple selector
+            "node2.config.settings.timeout": ["node2", "timeout"],  # Complex key mapped to simple selector
+        }
+
+        # User inputs
+        user_inputs = {
+            "node1.data.field1": "nested value",
+            "node2.config.settings.timeout": 30,
+        }
+
+        # Execute mapping
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id="test_tenant",
+        )
+
+        # Verify variables were added with simple selectors
+        data_var = variable_pool.get(["node1", "data_field1"])
+        assert data_var is not None
+        assert data_var.value == "nested value"
+
+        timeout_var = variable_pool.get(["node2", "timeout"])
+        assert timeout_var is not None
+        assert timeout_var.value == 30
+
+    def test_mapping_user_inputs_invalid_node_variable(self):
+        """Test that mapping handles invalid node variable format."""
+        variable_pool = VariablePool(
+            system_variables=SystemVariable.empty(),
+            user_inputs={},
+        )
+
+        # Define a variable mapping whose key has no dot separator (no node prefix)
+        variable_mapping = {
+            "singleelement": ["node1", "input"],  # No dot separator
+        }
+
+        user_inputs = {"singleelement": "some value"}  # Must use exact key
+
+        # Should NOT raise error - function accepts it and uses direct key
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id="test_tenant",
+        )
+
+        # Verify it was added
+        var = variable_pool.get(["node1", "input"])
+        assert var is not None
+        assert var.value == "some value"
+
+    def test_mapping_all_variable_types_together(self):
+        """Test mapping all four types of variables in one operation."""
+        # Initialize variable pool with some existing variables
+        env_var = StringVariable(name="API_KEY", value="existing_key")
+        conv_var = StringVariable(name="session_id", value="session123")
+
+        variable_pool = VariablePool(
+            system_variables=SystemVariable(
+                user_id="test_user",
+                app_id="test_app",
+                query="initial query",
+            ),
+            environment_variables=[env_var],
+            conversation_variables=[conv_var],
+            user_inputs={},
+        )
+
+        # Add existing variables to pool
+        variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var)
+
variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "session_id"], conv_var) + + # Define comprehensive variable mapping + variable_mapping = { + # System variables mapped to regular nodes + "sys.user_id": ["start", "user"], + "sys.app_id": ["start", "app"], + # Environment variables (won't be overridden) + "env.API_KEY": ["config", "api_key"], + # Conversation variables mapped to regular nodes + "conversation.session_id": ["chat", "session"], + # Regular variables + "input.text": ["input", "text"], + "process.data": ["process", "data"], + } + + # User inputs + user_inputs = { + "sys.user_id": "new_user", + "sys.app_id": "new_app", + "env.API_KEY": "attempted_override", # Should not override env var + "conversation.session_id": "new_session", + "input.text": "user input text", + "process.data": {"value": 123, "status": "active"}, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify system variables were added to their target nodes + start_user = variable_pool.get(["start", "user"]) + assert start_user is not None + assert start_user.value == "new_user" + + start_app = variable_pool.get(["start", "app"]) + assert start_app is not None + assert start_app.value == "new_app" + + # Verify env variable was not overridden (still has original value) + env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"]) + assert env_value is not None + assert env_value.value == "existing_key" + + # Environment variables get mapped to other nodes even when they exist in env pool + # But the original env value remains unchanged + config_api_key = variable_pool.get(["config", "api_key"]) + assert config_api_key is not None + assert config_api_key.value == "attempted_override" + + # Verify conversation variable was mapped to target node + chat_session = variable_pool.get(["chat", "session"]) + assert chat_session is not None + assert chat_session.value == "new_session" + + # Verify regular variables were added + input_text = variable_pool.get(["input", "text"]) + assert input_text is not None + assert input_text.value == "user input text" + + process_data = variable_pool.get(["process", "data"]) + assert process_data is not None + assert process_data.value == {"value": 123, "status": "active"} diff --git a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py index 7d295cecf2..958072223e 100644 --- a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py @@ -11,12 +11,12 @@ class TestSupabaseStorage: def test_init_success_with_all_config(self): """Test successful initialization when all required config is provided.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -31,7 +31,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_url_missing(self): """Test 
initialization raises ValueError when SUPABASE_URL is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = None mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" @@ -41,7 +41,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_api_key_missing(self): """Test initialization raises ValueError when SUPABASE_API_KEY is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = None mock_config.SUPABASE_BUCKET_NAME = "test-bucket" @@ -51,7 +51,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_bucket_name_missing(self): """Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = None @@ -61,12 +61,12 @@ class TestSupabaseStorage: def test_create_bucket_when_not_exists(self): """Test create_bucket creates bucket when it doesn't exist.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -77,12 +77,12 @@ class TestSupabaseStorage: def test_create_bucket_when_exists(self): """Test create_bucket does not create bucket when it already exists.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -94,12 +94,12 @@ class TestSupabaseStorage: @pytest.fixture def storage_with_mock_client(self): """Fixture providing SupabaseStorage with mocked client.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -251,12 +251,12 @@ 
class TestSupabaseStorage: def test_bucket_exists_returns_true_when_bucket_found(self): """Test bucket_exists returns True when bucket is found in list.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -271,12 +271,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_false_when_bucket_not_found(self): """Test bucket_exists returns False when bucket is not found in list.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -294,12 +294,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_false_when_no_buckets(self): """Test bucket_exists returns False when no buckets exist.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 9e4e74bd0f..7c0eccbb8b 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,7 +4,7 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import given +from hypothesis import given, settings from hypothesis import strategies as st from core.file import File, FileTransferMethod, FileType @@ -486,13 +486,14 @@ def _generate_file(draw) -> File: def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]: return st.one_of( st.none(), - st.integers(), - st.floats(), - st.text(), + st.integers(min_value=-(10**6), max_value=10**6), + st.floats(allow_nan=True, allow_infinity=False), + st.text(max_size=50), _generate_file(), ) +@settings(max_examples=50) @given(_scalar_value()) def test_build_segment_and_extract_values_for_scalar_types(value): seg = variable_factory.build_segment(value) @@ -503,7 +504,8 @@ def test_build_segment_and_extract_values_for_scalar_types(value): assert seg.value == value -@given(st.lists(_scalar_value())) +@settings(max_examples=50) +@given(values=st.lists(_scalar_value(), max_size=20)) def 
test_build_segment_and_extract_values_for_array_types(values): seg = variable_factory.build_segment(values) assert seg.value == values diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py index b80c711cac..962a36fe03 100644 --- a/api/tests/unit_tests/libs/test_email_i18n.py +++ b/api/tests/unit_tests/libs/test_email_i18n.py @@ -246,6 +246,43 @@ class TestEmailI18nService: sent_email = mock_sender.sent_emails[0] assert sent_email["subject"] == "Reset Your Dify Password" + def test_subject_format_keyerror_fallback_path( + self, + mock_renderer: MockEmailRenderer, + mock_sender: MockEmailSender, + ): + """Trigger subject KeyError and cover except branch.""" + # Config with subject that references an unknown key (no {application_title} to avoid second format) + config = EmailI18nConfig( + templates={ + EmailType.INVITE_MEMBER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Invite: {unknown_placeholder}", + template_path="invite_member_en.html", + branded_template_path="branded/invite_member_en.html", + ), + } + } + ) + branding_service = MockBrandingService(enabled=False) + service = EmailI18nService( + config=config, + renderer=mock_renderer, + branding_service=branding_service, + sender=mock_sender, + ) + + # Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback + service.send_email( + email_type=EmailType.INVITE_MEMBER, + language_code="en-US", + to="test@example.com", + ) + + assert len(mock_sender.sent_emails) == 1 + # Subject is left unformatted due to KeyError fallback path without application_title + assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}" + def test_send_change_email_old_phase( self, email_config: EmailI18nConfig, diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py new file mode 100644 index 0000000000..a9edb913ea --- /dev/null +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -0,0 +1,122 @@ +from flask import Blueprint, Flask +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, Unauthorized + +from core.errors.error import AppInvokeQuotaExceededError +from libs.external_api import ExternalApi + + +def _create_api_app(): + app = Flask(__name__) + bp = Blueprint("t", __name__) + api = ExternalApi(bp) + + @api.route("/bad-request") + class Bad(Resource): # type: ignore + def get(self): # type: ignore + raise BadRequest("invalid input") + + @api.route("/unauth") + class Unauth(Resource): # type: ignore + def get(self): # type: ignore + raise Unauthorized("auth required") + + @api.route("/value-error") + class ValErr(Resource): # type: ignore + def get(self): # type: ignore + raise ValueError("boom") + + @api.route("/quota") + class Quota(Resource): # type: ignore + def get(self): # type: ignore + raise AppInvokeQuotaExceededError("quota exceeded") + + @api.route("/general") + class Gen(Resource): # type: ignore + def get(self): # type: ignore + raise RuntimeError("oops") + + # Note: We avoid altering default_mediatype to keep normal error paths + + # Special 400 message rewrite + @api.route("/json-empty") + class JsonEmpty(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Force the specific message the handler rewrites + e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" + raise e + + # 400 mapping payload path + @api.route("/param-errors") + class ParamErrors(Resource): # type: ignore + def 
get(self): # type: ignore + e = BadRequest() + # Coerce a mapping description to trigger param error shaping + e.description = {"field": "is required"} # type: ignore[assignment] + raise e + + app.register_blueprint(bp, url_prefix="/api") + return app + + +def test_external_api_error_handlers_basic_paths(): + app = _create_api_app() + client = app.test_client() + + # 400 + res = client.get("/api/bad-request") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "bad_request" + assert data["status"] == 400 + + # 401 + res = client.get("/api/unauth") + assert res.status_code == 401 + assert "WWW-Authenticate" in res.headers + + # 400 ValueError + res = client.get("/api/value-error") + assert res.status_code == 400 + assert res.get_json()["code"] == "invalid_param" + + # 500 general + res = client.get("/api/general") + assert res.status_code == 500 + assert res.get_json()["status"] == 500 + + +def test_external_api_json_message_and_bad_request_rewrite(): + app = _create_api_app() + client = app.test_client() + + # JSON empty special rewrite + res = client.get("/api/json-empty") + assert res.status_code == 400 + assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty." + + +def test_external_api_param_mapping_and_quota_and_exc_info_none(): + # Force exc_info() to return (None,None,None) only during request + import libs.external_api as ext + + orig_exc_info = ext.sys.exc_info + try: + ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + + app = _create_api_app() + client = app.test_client() + + # Param errors mapping payload path + res = client.get("/api/param-errors") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "invalid_param" + assert data["params"] == "field" + + # Quota path — depending on Flask-RESTX internals it may be handled + res = client.get("/api/quota") + assert res.status_code in (400, 429) + finally: + ext.sys.exc_info = orig_exc_info # type: ignore[assignment] diff --git a/api/tests/unit_tests/libs/test_file_utils.py b/api/tests/unit_tests/libs/test_file_utils.py new file mode 100644 index 0000000000..8d9b4e803a --- /dev/null +++ b/api/tests/unit_tests/libs/test_file_utils.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import pytest + +from libs.file_utils import search_file_upwards + + +def test_search_file_upwards_found_in_parent(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + + found = search_file_upwards(base, "target.txt", max_search_parent_depth=5) + assert found == target + + +def test_search_file_upwards_found_in_current(tmp_path: Path): + base = tmp_path / "x" + base.mkdir() + target = base / "here.txt" + target.write_text("x", encoding="utf-8") + + found = search_file_upwards(base, "here.txt", max_search_parent_depth=1) + assert found == target + + +def test_search_file_upwards_not_found_raises(tmp_path: Path): + base = tmp_path / "m" / "n" + base.mkdir(parents=True) + with pytest.raises(ValueError) as exc: + search_file_upwards(base, "missing.txt", max_search_parent_depth=3) + # error message should contain file name and base path + msg = str(exc.value) + assert "missing.txt" in msg + assert str(base) in msg + + +def test_search_file_upwards_root_breaks_and_raises(): + # Using filesystem root triggers the 'break' branch (parent == current) + with pytest.raises(ValueError): + search_file_upwards(Path("/"), 
"__definitely_not_exists__.txt", max_search_parent_depth=1) + + +def test_search_file_upwards_depth_limit_raises(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + # The file is 2 levels up from `c` (in `a`), but search depth is only 2. + # The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3). + # So, this should not find the file and should raise an error. + with pytest.raises(ValueError): + search_file_upwards(base, "target.txt", max_search_parent_depth=2) diff --git a/api/tests/unit_tests/libs/test_flask_utils.py b/api/tests/unit_tests/libs/test_flask_utils.py index fb46ba50f3..e30433bfce 100644 --- a/api/tests/unit_tests/libs/test_flask_utils.py +++ b/api/tests/unit_tests/libs/test_flask_utils.py @@ -1,6 +1,5 @@ import contextvars import threading -from typing import Optional import pytest from flask import Flask @@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask: login_manager.init_app(app) @login_manager.user_loader - def load_user(user_id: str) -> Optional[User]: + def load_user(user_id: str) -> User | None: if user_id == "test_user": return User("test_user") return None diff --git a/api/tests/unit_tests/libs/test_json_in_md_parser.py b/api/tests/unit_tests/libs/test_json_in_md_parser.py new file mode 100644 index 0000000000..53fd0bea16 --- /dev/null +++ b/api/tests/unit_tests/libs/test_json_in_md_parser.py @@ -0,0 +1,88 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from libs.json_in_md_parser import ( + parse_and_check_json_markdown, + parse_json_markdown, +) + + +def test_parse_json_markdown_triple_backticks_json(): + src = """ + ```json + {"a": 1, "b": "x"} + ``` + """ + assert parse_json_markdown(src) == {"a": 1, "b": "x"} + + +def test_parse_json_markdown_triple_backticks_generic(): + src = """ + ``` + {"k": [1, 2, 3]} + ``` + """ + assert parse_json_markdown(src) == {"k": [1, 2, 3]} + + +def test_parse_json_markdown_single_backticks(): + src = '`{"x": true}`' + assert parse_json_markdown(src) == {"x": True} + + +def test_parse_json_markdown_braces_only(): + src = ' {\n \t"ok": "yes"\n} ' + assert parse_json_markdown(src) == {"ok": "yes"} + + +def test_parse_json_markdown_not_found(): + with pytest.raises(ValueError): + parse_json_markdown("no json here") + + +def test_parse_and_check_json_markdown_missing_key(): + src = """ + ``` + {"present": 1} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, ["present", "missing"]) + assert "expected key `missing`" in str(exc.value) + + +def test_parse_and_check_json_markdown_invalid_json(): + src = """ + ```json + {invalid json} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, []) + assert "got invalid json object" in str(exc.value) + + +def test_parse_and_check_json_markdown_success(): + src = """ + ```json + {"present": 1, "other": 2} + ``` + """ + obj = parse_and_check_json_markdown(src, ["present"]) + assert obj == {"present": 1, "other": 2} + + +def test_parse_and_check_json_markdown_multiple_blocks_fails(): + src = """ + ```json + {"a": 1} + ``` + Some text + ```json + {"b": 2} + ``` + """ + # The current implementation is greedy and will match from the first + # opening fence to the last closing fence, causing JSON decode failure. 
+ with pytest.raises(OutputParserError): + parse_and_check_json_markdown(src, []) diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py new file mode 100644 index 0000000000..3e0c235fff --- /dev/null +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -0,0 +1,19 @@ +import pytest + +from libs.oauth import OAuth + + +def test_oauth_base_methods_raise_not_implemented(): + oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri") + + with pytest.raises(NotImplementedError): + oauth.get_authorization_url() + + with pytest.raises(NotImplementedError): + oauth.get_access_token("code") + + with pytest.raises(NotImplementedError): + oauth.get_raw_user_info("token") + + with pytest.raises(NotImplementedError): + oauth._transform_user_info({}) # type: ignore[name-defined] diff --git a/api/tests/unit_tests/libs/test_orjson.py b/api/tests/unit_tests/libs/test_orjson.py new file mode 100644 index 0000000000..6df1d077df --- /dev/null +++ b/api/tests/unit_tests/libs/test_orjson.py @@ -0,0 +1,25 @@ +import orjson +import pytest + +from libs.orjson import orjson_dumps + + +def test_orjson_dumps_round_trip_basic(): + obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}} + s = orjson_dumps(obj) + assert orjson.loads(s) == obj + + +def test_orjson_dumps_with_unicode_and_indent(): + obj = {"msg": "你好,Dify"} + s = orjson_dumps(obj, option=orjson.OPT_INDENT_2) + # contains indentation newline/spaces + assert "\n" in s + assert orjson.loads(s) == obj + + +def test_orjson_dumps_non_utf8_encoding_fails(): + obj = {"msg": "你好"} + # orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails. + with pytest.raises(UnicodeDecodeError): + orjson_dumps(obj, encoding="ascii") diff --git a/api/tests/unit_tests/libs/test_sendgrid_client.py b/api/tests/unit_tests/libs/test_sendgrid_client.py new file mode 100644 index 0000000000..85744003c7 --- /dev/null +++ b/api/tests/unit_tests/libs/test_sendgrid_client.py @@ -0,0 +1,53 @@ +from unittest.mock import MagicMock, patch + +import pytest +from python_http_client.exceptions import UnauthorizedError + +from libs.sendgrid import SendGridClient + + +def _mail(to: str = "user@example.com") -> dict: + return {"to": to, "subject": "Hi", "html": "Hi"} + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_success(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + # nested attribute access: client.mail.send.post + mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + sg.send(_mail()) + + mock_client_cls.assert_called_once() + mock_client.client.mail.send.post.assert_called_once() + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock): + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(ValueError): + sg.send(_mail(to="")) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(UnauthorizedError): + sg.send(_mail()) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def 
test_sendgrid_timeout_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = TimeoutError("timeout") + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(TimeoutError): + sg.send(_mail()) diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py new file mode 100644 index 0000000000..fcee01ca00 --- /dev/null +++ b/api/tests/unit_tests/libs/test_smtp_client.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from libs.smtp import SMTPClient + + +def _mail() -> dict: + return {"to": "user@example.com", "subject": "Hi", "html": "Hi"} + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_plain_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=587, + username="user", + password="pass", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=True, + ) + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10) + assert mock_smtp.ehlo.call_count == 2 + mock_smtp.starttls.assert_called_once() + mock_smtp.login.assert_called_once_with("user", "pass") + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP_SSL") +def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock): + # Cover SMTP_SSL branch and TimeoutError handling + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = TimeoutError("timeout") + mock_smtp_ssl_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="", + password="", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + with pytest.raises(TimeoutError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = RuntimeError("oops") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + with pytest.raises(RuntimeError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock): + # Ensure we hit the specific SMTPException except branch + import smtplib + + mock_smtp = MagicMock() + mock_smtp.login.side_effect = smtplib.SMTPException("login-fail") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=25, + username="user", # non-empty to trigger login + password="pass", + _from="noreply@example.com", + ) + with pytest.raises(smtplib.SMTPException): + client.send(_mail()) + mock_smtp.quit.assert_called_once() diff --git 
a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index dc42a04cf3..d23298f096 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -28,18 +28,20 @@ class TestApiKeyAuthService: mock_binding.provider = self.provider mock_binding.disabled = False - mock_session.query.return_value.where.return_value.all.return_value = [mock_binding] + mock_session.scalars.return_value.all.return_value = [mock_binding] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) assert len(result) == 1 assert result[0].tenant_id == self.tenant_id - mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) + assert mock_session.scalars.call_count == 1 + select_arg = mock_session.scalars.call_args[0][0] + assert "data_source_api_key_auth_binding" in str(select_arg).lower() @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_empty(self, mock_session): """Test get provider auth list - empty result""" - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) @@ -48,13 +50,15 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_filters_disabled(self, mock_session): """Test get provider auth list - filters disabled items""" - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - - # Verify where conditions include disabled.is_(False) - where_call = mock_session.query.return_value.where.call_args[0] - assert len(where_call) == 2 # tenant_id and disabled filter conditions + select_stmt = mock_session.scalars.call_args[0][0] + where_clauses = list(getattr(select_stmt, "_where_criteria", []) or []) + # Ensure both tenant filter and disabled filter exist + where_strs = [str(c).lower() for c in where_clauses] + assert any("tenant_id" in s for s in where_strs) + assert any("disabled" in s for s in where_strs) @patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index 4ce5525942..bb39b92c09 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -63,10 +63,10 @@ class TestAuthIntegration: tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] + mock_session.scalars.return_value.all.return_value = [tenant1_binding] result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] + mock_session.scalars.return_value.all.return_value = [tenant2_binding] result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) assert len(result1) == 1 diff --git a/api/tests/unit_tests/services/test_account_service.py 
b/api/tests/unit_tests/services/test_account_service.py index 442839e44e..737202f8de 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -10,7 +10,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, @@ -195,7 +194,7 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password" + AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password" ) def test_authenticate_account_banned(self, mock_db_dependencies): @@ -1370,8 +1369,8 @@ class TestRegisterService: account_id="user-123", email="test@example.com" ) - with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token: - # Mock the invitation data returned by _get_invitation_by_token + with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token: + # Mock the invitation data returned by get_invitation_by_token invitation_data = { "account_id": "user-123", "email": "test@example.com", @@ -1503,12 +1502,12 @@ class TestRegisterService: assert result == "member_invite:token:test-token" def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token with workspace ID and email.""" + """Test get_invitation_by_token with workspace ID and email.""" # Setup mock mock_redis_dependencies.get.return_value = b"user-123" # Execute test - result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com") + result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com") # Verify results assert result is not None @@ -1517,7 +1516,7 @@ class TestRegisterService: assert result["workspace_id"] == "workspace-456" def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token without workspace ID and email.""" + """Test get_invitation_by_token without workspace ID and email.""" # Setup mock invitation_data = { "account_id": "user-123", @@ -1527,19 +1526,19 @@ class TestRegisterService: mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is not None assert result == invitation_data def test_get_invitation_by_token_no_data(self, mock_redis_dependencies): - """Test _get_invitation_by_token with no data.""" + """Test get_invitation_by_token with no data.""" # Setup mock mock_redis_dependencies.get.return_value = None # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is None
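
# --- Reviewer note: the authenticate() assertion change above encodes a behavior
# shift worth calling out: an unknown email now raises the same AccountPasswordError
# as a wrong password, so a login attempt can no longer be used to probe which
# addresses have accounts. A self-contained toy of that anti-enumeration shape
# (names and the plaintext lookup are illustrative only; the real service checks a
# hashed password):
class AccountPasswordError(Exception):
    pass


def authenticate(accounts: dict[str, str], email: str, password: str) -> str:
    stored = accounts.get(email)
    # One error path covers both "no such account" and "wrong password"
    if stored is None or stored != password:
        raise AccountPasswordError("Invalid email or password.")
    return email

diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index 1881ceac26..69766188f3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ 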
b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional # Mock redis_client before importing dataset_service from unittest.mock import Mock, call, patch @@ -37,7 +36,7 @@ class DocumentBatchUpdateTestDataFactory: enabled: bool = True, archived: bool = False, indexing_status: str = "completed", - completed_at: Optional[datetime.datetime] = None, + completed_at: datetime.datetime | None = None, **kwargs, ) -> Mock: """Create a mock document with specified attributes.""" diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index fb23863043..df5596f5c8 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, Optional +from typing import Any # Mock redis_client before importing dataset_service from unittest.mock import Mock, create_autospec, patch @@ -24,9 +24,9 @@ class DatasetUpdateTestDataFactory: description: str = "old_description", indexing_technique: str = "high_quality", retrieval_model: str = "old_model", - embedding_model_provider: Optional[str] = None, - embedding_model: Optional[str] = None, - collection_binding_id: Optional[str] = None, + embedding_model_provider: str | None = None, + embedding_model: str | None = None, + collection_binding_id: str | None = None, **kwargs, ) -> Mock: """Create a mock dataset with specified attributes.""" diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index ad65175e89..0ff1edc950 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import Mock, create_autospec, patch import pytest @@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation: # Console API create console_create_file = "api/controllers/console/datasets/metadata.py" if os.path.exists(console_create_file): - with open(console_create_file) as f: - content = f.read() - # Should contain nullable=False, not nullable=True - assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] + content = Path(console_create_file).read_text() + # Should contain nullable=False, not nullable=True + assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] # Service API create service_create_file = "api/controllers/service_api/dataset/metadata.py" if os.path.exists(service_create_file): - with open(service_create_file) as f: - content = f.read() - # Should contain nullable=False, not nullable=True - create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] - assert "nullable=True" not in create_api_section + content = Path(service_create_file).read_text() + # Should contain nullable=False, not nullable=True + create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] + assert "nullable=True" not in create_api_section class TestMetadataValidationSummary: diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 0a09167349..2ca781bae5 100644 --- 
a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT.value + app_model.mode = AppMode.CHAT api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( @@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW.value + app_model.mode = AppMode.WORKFLOW api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 93284eed4b..9046f785d2 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -279,8 +279,6 @@ def test_structured_output_parser(): ] for case in testcases: - print(f"Running test case: {case['name']}") - # Setup model entity model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"]) diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d5..9e2b0659c0 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -3,7 +3,7 @@ from textwrap import dedent import pytest from yaml import YAMLError -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import _load_yaml_file EXAMPLE_YAML_FILE = "example_yaml.yaml" INVALID_YAML_FILE = "invalid_yaml.yaml" @@ -56,15 +56,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: def test_load_yaml_non_existing_file(): - assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path="") == {} + with pytest.raises(FileNotFoundError): + _load_yaml_file(file_path=NON_EXISTING_YAML_FILE) with pytest.raises(FileNotFoundError): - load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) + _load_yaml_file(file_path="") def test_load_valid_yaml_file(prepare_example_yaml_file): - yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) + yaml_data = _load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 assert yaml_data["age"] == 30 assert yaml_data["gender"] == "male" @@ -77,7 +77,4 @@ def test_load_invalid_yaml_file(prepare_invalid_yaml_file): # yaml syntax error with pytest.raises(YAMLError): - load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) - - # ignore error - assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {} + _load_yaml_file(file_path=prepare_invalid_yaml_file)
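
# --- Reviewer note: the yaml_utils tests above now target a stricter, underscored
# loader. A minimal sketch of the assumed _load_yaml_file shape, inferred from the
# test expectations rather than copied from core/tools/utils/yaml_utils.py: the
# ignore_error escape hatch is gone, empty or missing paths raise
# FileNotFoundError, and malformed YAML propagates YAMLError.
from pathlib import Path

import yaml


def _load_yaml_file(file_path: str) -> dict:
    if not file_path or not Path(file_path).exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    with open(file_path, encoding="utf-8") as f:
        # yaml.safe_load raises YAMLError on syntax errors; an empty document
        # loads as None, normalized to {} here
        return yaml.safe_load(f) or {}

diff --git a/api/uv.lock index 342d8493a2..788499f88f 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -538,6 +538,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl", hash = 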
"sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a", size = 142979, upload-time = "2023-04-07T15:02:50.77Z" }, ] +[[package]] +name = "bidict" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093, upload-time = "2024-02-18T19:09:05.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size = 32764, upload-time = "2024-02-18T19:09:04.156Z" }, +] + [[package]] name = "billiard" version = "4.2.1" @@ -1009,7 +1018,7 @@ wheels = [ [[package]] name = "clickzetta-connector-python" -version = "0.8.102" +version = "0.8.104" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, @@ -1023,7 +1032,7 @@ dependencies = [ { name = "urllib3" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" }, + { url = "https://files.pythonhosted.org/packages/8f/94/c7eee2224bdab39d16dfe5bb7687f5525c7ed345b7fe8812e18a2d9a6335/clickzetta_connector_python-0.8.104-py3-none-any.whl", hash = "sha256:ae3e466d990677f96c769ec1c29318237df80c80fe9c1e21ba1eaf42bdef0207", size = 79382, upload-time = "2025-09-10T08:46:39.731Z" }, ] [[package]] @@ -1061,6 +1070,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018, upload-time = "2021-06-11T10:22:42.561Z" }, ] +[[package]] +name = "configargparse" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/4d/6c9ef746dfcc2a32e26f3860bb4a011c008c392b83eabdfb598d1a8bbe5d/configargparse-1.7.1.tar.gz", hash = "sha256:79c2ddae836a1e5914b71d58e4b9adbd9f7779d4e6351a637b7d2d9b6c46d3d9", size = 43958, upload-time = "2025-05-23T14:26:17.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/28/d28211d29bcc3620b1fece85a65ce5bb22f18670a03cd28ea4b75ede270c/configargparse-1.7.1-py3-none-any.whl", hash = "sha256:8b586a31f9d873abd1ca527ffbe58863c99f36d896e2829779803125e83be4b6", size = 25607, upload-time = "2025-05-23T14:26:15.923Z" }, +] + [[package]] name = "cos-python-sdk-v5" version = "1.9.30" @@ -1358,6 +1376,7 @@ dev = [ { name = "faker" }, { name = "hypothesis" }, { name = "import-linter" }, + { name = "locust" }, { name = "lxml-stubs" }, { name = "mypy" }, { name = "pandas-stubs" }, @@ -1368,6 +1387,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "scipy-stubs" }, + { name = "sseclient-py" }, { name = "testcontainers" }, { name = "ty" }, { name = "types-aiofiles" }, @@ -1534,7 +1554,7 @@ requires-dist = [ { name = "sseclient-py", specifier = "~=1.8.0" }, { name = "starlette", specifier = "==0.47.2" }, { name = "tiktoken", specifier = "~=0.9.0" }, - { name = "transformers", specifier = "~=4.53.0" }, + 
{ name = "transformers", specifier = "~=4.56.1" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, { name = "weave", specifier = "~=0.51.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, @@ -1551,6 +1571,7 @@ dev = [ { name = "faker", specifier = "~=32.1.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, { name = "import-linter", specifier = ">=2.3" }, + { name = "locust", specifier = ">=2.40.4" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.17.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, @@ -1561,6 +1582,7 @@ dev = [ { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "ruff", specifier = "~=0.12.3" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, + { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.10.0" }, { name = "ty", specifier = "~=0.0.1a19" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, @@ -2038,6 +2060,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/b2/5d20664ef6a077bec9f27f7a7ee761edc64946d0b1e293726a3d074a9a18/gevent-24.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:68bee86b6e1c041a187347ef84cf03a792f0b6c7238378bf6ba4118af11feaae", size = 1541631, upload-time = "2024-11-11T14:55:34.977Z" }, ] +[[package]] +name = "geventhttpclient" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "brotli" }, + { name = "certifi" }, + { name = "gevent" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/19/1ca8de73dcc0596d3df01be299e940d7fc3bccbeb6f62bb8dd2d427a3a50/geventhttpclient-2.3.4.tar.gz", hash = "sha256:1749f75810435a001fc6d4d7526c92cf02b39b30ab6217a886102f941c874222", size = 83545, upload-time = "2025-06-11T13:18:14.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/c7/c4c31bd92b08c4e34073c722152b05c48c026bc6978cf04f52be7e9050d5/geventhttpclient-2.3.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb8f6a18f1b5e37724111abbd3edf25f8f00e43dc261b11b10686e17688d2405", size = 71919, upload-time = "2025-06-11T13:16:49.796Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8a/4565e6e768181ecb06677861d949b3679ed29123b6f14333e38767a17b5a/geventhttpclient-2.3.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dbb28455bb5d82ca3024f9eb7d65c8ff6707394b584519def497b5eb9e5b1222", size = 52577, upload-time = "2025-06-11T13:16:50.657Z" }, + { url = "https://files.pythonhosted.org/packages/02/a1/fb623cf478799c08f95774bc41edb8ae4c2f1317ae986b52f233d0f3fa05/geventhttpclient-2.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96578fc4a5707b5535d1c25a89e72583e02aafe64d14f3b4d78f9c512c6d613c", size = 51981, upload-time = "2025-06-11T13:16:52.586Z" }, + { url = "https://files.pythonhosted.org/packages/18/b2/a4ddd3d24c8aa064b19b9f180eb5e1517248518289d38af70500569ebedf/geventhttpclient-2.3.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19721357db976149ccf54ac279eab8139da8cdf7a11343fd02212891b6f39677", size = 114287, upload-time = "2025-08-24T12:16:47.101Z" }, + { url = "https://files.pythonhosted.org/packages/a1/cc/caac4d4bd2c72d53836dbf50018aed3747c0d0c6f1d08175a785083d9d36/geventhttpclient-2.3.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecf830cdcd1d4d28463c8e0c48f7f5fb06f3c952fff875da279385554d1d4d65", size = 115208, upload-time = "2025-08-24T12:16:48.108Z" }, + { url = 
"https://files.pythonhosted.org/packages/04/a2/8278bd4d16b9df88bd538824595b7b84efd6f03c7b56b2087d09be838e02/geventhttpclient-2.3.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:47dbf8a163a07f83b38b0f8a35b85e5d193d3af4522ab8a5bbecffff1a4cd462", size = 121101, upload-time = "2025-08-24T12:16:49.417Z" }, + { url = "https://files.pythonhosted.org/packages/e3/0e/a9ebb216140bd0854007ff953094b2af983cdf6d4aec49796572fcbf2606/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e39ad577b33a5be33b47bff7c2dda9b19ced4773d169d6555777cd8445c13c0", size = 118494, upload-time = "2025-06-11T13:16:54.172Z" }, + { url = "https://files.pythonhosted.org/packages/4f/95/6d45dead27e4f5db7a6d277354b0e2877c58efb3cd1687d90a02d5c7b9cd/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:110d863baf7f0a369b6c22be547c5582e87eea70ddda41894715c870b2e82eb0", size = 123860, upload-time = "2025-06-11T13:16:55.824Z" }, + { url = "https://files.pythonhosted.org/packages/70/a1/4baa8dca3d2df94e6ccca889947bb5929aca5b64b59136bbf1779b5777ba/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d9fca98469bd770e3efd88326854296d1aa68016f285bd1a2fb6cd21e17ee", size = 114969, upload-time = "2025-06-11T13:16:58.02Z" }, + { url = "https://files.pythonhosted.org/packages/ab/48/123fa67f6fca14c557332a168011565abd9cbdccc5c8b7ed76d9a736aeb2/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71dbc6d4004017ef88c70229809df4ad2317aad4876870c0b6bcd4d6695b7a8d", size = 113311, upload-time = "2025-06-11T13:16:59.423Z" }, + { url = "https://files.pythonhosted.org/packages/93/e4/8a467991127ca6c53dd79a8aecb26a48207e7e7976c578fb6eb31378792c/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ed35391ad697d6cda43c94087f59310f028c3e9fb229e435281a92509469c627", size = 111154, upload-time = "2025-06-11T13:17:01.139Z" }, + { url = "https://files.pythonhosted.org/packages/11/e7/cca0663d90bc8e68592a62d7b28148eb9fd976f739bb107e4c93f9ae6d81/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:97cd2ab03d303fd57dea4f6d9c2ab23b7193846f1b3bbb4c80b315ebb5fc8527", size = 112532, upload-time = "2025-06-11T13:17:03.729Z" }, + { url = "https://files.pythonhosted.org/packages/02/98/625cee18a3be5f7ca74c612d4032b0c013b911eb73c7e72e06fa56a44ba2/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ec4d1aa08569b7eb075942caeacabefee469a0e283c96c7aac0226d5e7598fe8", size = 117806, upload-time = "2025-06-11T13:17:05.138Z" }, + { url = "https://files.pythonhosted.org/packages/f1/5e/e561a5f8c9d98b7258685355aacb9cca8a3c714190cf92438a6e91da09d5/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:93926aacdb0f4289b558f213bc32c03578f3432a18b09e4b6d73a716839d7a74", size = 111392, upload-time = "2025-06-11T13:17:06.053Z" }, + { url = "https://files.pythonhosted.org/packages/d0/37/42d09ad90fd1da960ff68facaa3b79418ccf66297f202ba5361038fc3182/geventhttpclient-2.3.4-cp311-cp311-win32.whl", hash = "sha256:ea87c25e933991366049a42c88e91ad20c2b72e11c7bd38ef68f80486ab63cb2", size = 48332, upload-time = "2025-06-11T13:17:06.965Z" }, + { url = 
"https://files.pythonhosted.org/packages/4b/0b/55e2a9ed4b1aed7c97e857dc9649a7e804609a105e1ef3cb01da857fbce7/geventhttpclient-2.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:e02e0e9ef2e45475cf33816c8fb2e24595650bcf259e7b15b515a7b49cae1ccf", size = 48969, upload-time = "2025-06-11T13:17:08.239Z" }, + { url = "https://files.pythonhosted.org/packages/4f/72/dcbc6dbf838549b7b0c2c18c1365d2580eb7456939e4b608c3ab213fce78/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9ac30c38d86d888b42bb2ab2738ab9881199609e9fa9a153eb0c66fc9188c6cb", size = 71984, upload-time = "2025-06-11T13:17:09.126Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f9/74aa8c556364ad39b238919c954a0da01a6154ad5e85a1d1ab5f9f5ac186/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b802000a4fad80fa57e895009671d6e8af56777e3adf0d8aee0807e96188fd9", size = 52631, upload-time = "2025-06-11T13:17:10.061Z" }, + { url = "https://files.pythonhosted.org/packages/11/1a/bc4b70cba8b46be8b2c6ca5b8067c4f086f8c90915eb68086ab40ff6243d/geventhttpclient-2.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:461e4d9f4caee481788ec95ac64e0a4a087c1964ddbfae9b6f2dc51715ba706c", size = 51991, upload-time = "2025-06-11T13:17:11.049Z" }, + { url = "https://files.pythonhosted.org/packages/03/3f/5ce6e003b3b24f7caf3207285831afd1a4f857ce98ac45e1fb7a6815bd58/geventhttpclient-2.3.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b7e41687c74e8fbe6a665458bbaea0c5a75342a95e2583738364a73bcbf1671b", size = 114982, upload-time = "2025-08-24T12:16:50.76Z" }, + { url = "https://files.pythonhosted.org/packages/60/16/6f9dad141b7c6dd7ee831fbcd72dd02535c57bc1ec3c3282f07e72c31344/geventhttpclient-2.3.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ea5da20f4023cf40207ce15f5f4028377ffffdba3adfb60b4c8f34925fce79", size = 115654, upload-time = "2025-08-24T12:16:52.072Z" }, + { url = "https://files.pythonhosted.org/packages/ba/52/9b516a2ff423d8bd64c319e1950a165ceebb552781c5a88c1e94e93e8713/geventhttpclient-2.3.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:91f19a8a6899c27867dbdace9500f337d3e891a610708e86078915f1d779bf53", size = 121672, upload-time = "2025-08-24T12:16:53.361Z" }, + { url = "https://files.pythonhosted.org/packages/b0/f5/8d0f1e998f6d933c251b51ef92d11f7eb5211e3cd579018973a2b455f7c5/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41f2dcc0805551ea9d49f9392c3b9296505a89b9387417b148655d0d8251b36e", size = 119012, upload-time = "2025-06-11T13:17:11.956Z" }, + { url = "https://files.pythonhosted.org/packages/ea/0e/59e4ab506b3c19fc72e88ca344d150a9028a00c400b1099637100bec26fc/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:62f3a29bf242ecca6360d497304900683fd8f42cbf1de8d0546c871819251dad", size = 124565, upload-time = "2025-06-11T13:17:12.896Z" }, + { url = "https://files.pythonhosted.org/packages/39/5d/dcbd34dfcda0c016b4970bd583cb260cc5ebfc35b33d0ec9ccdb2293587a/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8714a3f2c093aeda3ffdb14c03571d349cb3ed1b8b461d9f321890659f4a5dbf", size = 115573, upload-time = "2025-06-11T13:17:13.937Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/51/89af99e4805e9ce7f95562dfbd23c0b0391830831e43d58f940ec74489ac/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b11f38b74bab75282db66226197024a731250dcbe25542fd4e85ac5313547332", size = 114260, upload-time = "2025-06-11T13:17:14.913Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ec/3a3000bda432953abcc6f51d008166fa7abc1eeddd1f0246933d83854f73/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fccc2023a89dfbce2e1b1409b967011e45d41808df81b7fa0259397db79ba647", size = 111592, upload-time = "2025-06-11T13:17:15.879Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a3/88fd71fe6bbe1315a2d161cbe2cc7810c357d99bced113bea1668ede8bcf/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9d54b8e9a44890159ae36ba4ae44efd8bb79ff519055137a340d357538a68aa3", size = 113216, upload-time = "2025-06-11T13:17:16.883Z" }, + { url = "https://files.pythonhosted.org/packages/52/eb/20435585a6911b26e65f901a827ef13551c053133926f8c28a7cca0fb08e/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:407cb68a3c3a2c4f5d503930298f2b26ae68137d520e8846d8e230a9981d9334", size = 118450, upload-time = "2025-06-11T13:17:17.968Z" }, + { url = "https://files.pythonhosted.org/packages/2f/79/82782283d613570373990b676a0966c1062a38ca8f41a0f20843c5808e01/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:54fbbcca2dcf06f12a337dd8f98417a09a49aa9d9706aa530fc93acb59b7d83c", size = 112226, upload-time = "2025-06-11T13:17:18.942Z" }, + { url = "https://files.pythonhosted.org/packages/9c/c4/417d12fc2a31ad93172b03309c7f8c3a8bbd0cf25b95eb7835de26b24453/geventhttpclient-2.3.4-cp312-cp312-win32.whl", hash = "sha256:83143b41bde2eb010c7056f142cb764cfbf77f16bf78bda2323a160767455cf5", size = 48365, upload-time = "2025-06-11T13:17:20.096Z" }, + { url = "https://files.pythonhosted.org/packages/cf/f4/7e5ee2f460bbbd09cb5d90ff63a1cf80d60f1c60c29dac20326324242377/geventhttpclient-2.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:46eda9a9137b0ca7886369b40995d2a43a5dff033d0a839a54241015d1845d41", size = 48961, upload-time = "2025-06-11T13:17:21.111Z" }, + { url = "https://files.pythonhosted.org/packages/0b/a7/de506f91a1ec67d3c4a53f2aa7475e7ffb869a17b71b94ba370a027a69ac/geventhttpclient-2.3.4-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:707a66cd1e3bf06e2c4f8f21d3b4e6290c9e092456f489c560345a8663cdd93e", size = 50828, upload-time = "2025-06-11T13:17:57.589Z" }, + { url = "https://files.pythonhosted.org/packages/2b/43/86479c278e96cd3e190932b0003d5b8e415660d9e519d59094728ae249da/geventhttpclient-2.3.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:0129ce7ef50e67d66ea5de44d89a3998ab778a4db98093d943d6855323646fa5", size = 50086, upload-time = "2025-06-11T13:17:58.567Z" }, + { url = "https://files.pythonhosted.org/packages/e8/f7/d3e04f95de14db3ca4fe126eb0e3ec24356125c5ca1f471a9b28b1d7714d/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fac2635f68b3b6752c2a576833d9d18f0af50bdd4bd7dd2d2ca753e3b8add84c", size = 54523, upload-time = "2025-06-11T13:17:59.536Z" }, + { url = "https://files.pythonhosted.org/packages/45/a7/d80c9ec1663f70f4bd976978bf86b3d0d123a220c4ae636c66d02d3accdb/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:71206ab89abdd0bd5fee21e04a3995ec1f7d8ae1478ee5868f9e16e85a831653", size = 58866, upload-time = "2025-06-11T13:18:03.719Z" }, + { url = "https://files.pythonhosted.org/packages/55/92/d874ff7e52803cef3850bf8875816a9f32e0a154b079a74e6663534bef30/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8bde667d0ce46065fe57f8ff24b2e94f620a5747378c97314dcfc8fbab35b73", size = 54766, upload-time = "2025-06-11T13:18:04.724Z" }, + { url = "https://files.pythonhosted.org/packages/a8/73/2e03125170485193fcc99ef23b52749543d6c6711706d58713fe315869c4/geventhttpclient-2.3.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5f71c75fc138331cbbe668a08951d36b641d2c26fb3677d7e497afb8419538db", size = 49011, upload-time = "2025-06-11T13:18:05.702Z" }, +] + [[package]] name = "gitdb" version = "4.0.12" @@ -2679,7 +2753,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.33.2" +version = "0.34.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2691,9 +2765,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fa/42/8a95c5632080ae312c0498744b2b852195e10b05a20b1be11c5141092f4c/huggingface_hub-0.33.2.tar.gz", hash = "sha256:84221defaec8fa09c090390cd68c78b88e3c4c2b7befba68d3dc5aacbc3c2c5f", size = 426637, upload-time = "2025-07-02T06:26:05.156Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/c9/bdbe19339f76d12985bc03572f330a01a93c04dffecaaea3061bdd7fb892/huggingface_hub-0.34.4.tar.gz", hash = "sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c", size = 459768, upload-time = "2025-08-08T09:14:52.365Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/f4/5f3f22e762ad1965f01122b42dae5bf0e009286e2dba601ce1d0dba72424/huggingface_hub-0.33.2-py3-none-any.whl", hash = "sha256:3749498bfa91e8cde2ddc2c1db92c79981f40e66434c20133b39e5928ac9bcc5", size = 515373, upload-time = "2025-07-02T06:26:03.072Z" }, + { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" }, ] [[package]] @@ -3025,6 +3099,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/3b/a9a17366af80127bd09decbe2a54d8974b6d8b274b39bf47fbaedeec6307/llvmlite-0.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:eae7e2d4ca8f88f89d315b48c6b741dcb925d6a1042da694aa16ab3dd4cbd3a1", size = 30332380, upload-time = "2025-01-20T11:14:02.442Z" }, ] +[[package]] +name = "locust" +version = "2.40.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "flask" }, + { name = "flask-cors" }, + { name = "flask-login" }, + { name = "gevent" }, + { name = "geventhttpclient" }, + { name = "locust-cloud" }, + { name = "msgpack" }, + { name = "psutil" }, + { name = "pytest" }, + { name = "python-engineio" }, + { name = "python-socketio", extra = ["client"] }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pyzmq" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "werkzeug" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/c8/40/31ff56ab6f46c7c77e61bbbd23f87fdf6a4aaf674dc961a3c573320caedc/locust-2.40.4.tar.gz", hash = "sha256:3a3a470459edc4ba1349229bf1aca4c0cb651c4e2e3f85d3bc28fe8118f5a18f", size = 1412529, upload-time = "2025-09-11T09:26:13.713Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7e/db1d969caf45ce711e81cd4f3e7c4554c3925a02383a1dcadb442eae3802/locust-2.40.4-py3-none-any.whl", hash = "sha256:50e647a73c5a4e7a775c6e4311979472fce8b00ed783837a2ce9bb36786f7d1a", size = 1430961, upload-time = "2025-09-11T09:26:11.623Z" }, +] + +[[package]] +name = "locust-cloud" +version = "1.26.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "gevent" }, + { name = "platformdirs" }, + { name = "python-engineio" }, + { name = "python-socketio", extra = ["client"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/ad/10b299b134068a4250a9156e6832a717406abe1dfea2482a07ae7bdca8f3/locust_cloud-1.26.3.tar.gz", hash = "sha256:587acfd4d2dee715fb5f0c3c2d922770babf0b7cff7b2927afbb693a9cd193cc", size = 456042, upload-time = "2025-07-15T19:51:53.791Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/6a/276fc50a9d170e7cbb6715735480cb037abb526639bca85491576e6eee4a/locust_cloud-1.26.3-py3-none-any.whl", hash = "sha256:8cb4b8bb9adcd5b99327bc8ed1d98cf67a29d9d29512651e6e94869de6f1faa8", size = 410023, upload-time = "2025-07-15T19:51:52.056Z" }, +] + [[package]] name = "lxml" version = "6.0.0" @@ -3296,6 +3415,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, ] +[[package]] +name = "msgpack" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/45/b1/ea4f68038a18c77c9467400d166d74c4ffa536f34761f7983a104357e614/msgpack-1.1.1.tar.gz", hash = "sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd", size = 173555, upload-time = "2025-06-13T06:52:51.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/83/97f24bf9848af23fe2ba04380388216defc49a8af6da0c28cc636d722502/msgpack-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558", size = 82728, upload-time = "2025-06-13T06:51:50.68Z" }, + { url = "https://files.pythonhosted.org/packages/aa/7f/2eaa388267a78401f6e182662b08a588ef4f3de6f0eab1ec09736a7aaa2b/msgpack-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d", size = 79279, upload-time = "2025-06-13T06:51:51.72Z" }, + { url = "https://files.pythonhosted.org/packages/f8/46/31eb60f4452c96161e4dfd26dbca562b4ec68c72e4ad07d9566d7ea35e8a/msgpack-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0", size = 423859, upload-time = "2025-06-13T06:51:52.749Z" }, + { url = "https://files.pythonhosted.org/packages/45/16/a20fa8c32825cc7ae8457fab45670c7a8996d7746ce80ce41cc51e3b2bd7/msgpack-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f", size = 429975, upload-time = "2025-06-13T06:51:53.97Z" }, + { 
url = "https://files.pythonhosted.org/packages/86/ea/6c958e07692367feeb1a1594d35e22b62f7f476f3c568b002a5ea09d443d/msgpack-1.1.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704", size = 413528, upload-time = "2025-06-13T06:51:55.507Z" }, + { url = "https://files.pythonhosted.org/packages/75/05/ac84063c5dae79722bda9f68b878dc31fc3059adb8633c79f1e82c2cd946/msgpack-1.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2", size = 413338, upload-time = "2025-06-13T06:51:57.023Z" }, + { url = "https://files.pythonhosted.org/packages/69/e8/fe86b082c781d3e1c09ca0f4dacd457ede60a13119b6ce939efe2ea77b76/msgpack-1.1.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2", size = 422658, upload-time = "2025-06-13T06:51:58.419Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2b/bafc9924df52d8f3bb7c00d24e57be477f4d0f967c0a31ef5e2225e035c7/msgpack-1.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752", size = 427124, upload-time = "2025-06-13T06:51:59.969Z" }, + { url = "https://files.pythonhosted.org/packages/a2/3b/1f717e17e53e0ed0b68fa59e9188f3f610c79d7151f0e52ff3cd8eb6b2dc/msgpack-1.1.1-cp311-cp311-win32.whl", hash = "sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295", size = 65016, upload-time = "2025-06-13T06:52:01.294Z" }, + { url = "https://files.pythonhosted.org/packages/48/45/9d1780768d3b249accecc5a38c725eb1e203d44a191f7b7ff1941f7df60c/msgpack-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458", size = 72267, upload-time = "2025-06-13T06:52:02.568Z" }, + { url = "https://files.pythonhosted.org/packages/e3/26/389b9c593eda2b8551b2e7126ad3a06af6f9b44274eb3a4f054d48ff7e47/msgpack-1.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238", size = 82359, upload-time = "2025-06-13T06:52:03.909Z" }, + { url = "https://files.pythonhosted.org/packages/ab/65/7d1de38c8a22cf8b1551469159d4b6cf49be2126adc2482de50976084d78/msgpack-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157", size = 79172, upload-time = "2025-06-13T06:52:05.246Z" }, + { url = "https://files.pythonhosted.org/packages/0f/bd/cacf208b64d9577a62c74b677e1ada005caa9b69a05a599889d6fc2ab20a/msgpack-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce", size = 425013, upload-time = "2025-06-13T06:52:06.341Z" }, + { url = "https://files.pythonhosted.org/packages/4d/ec/fd869e2567cc9c01278a736cfd1697941ba0d4b81a43e0aa2e8d71dab208/msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a", size = 426905, upload-time = "2025-06-13T06:52:07.501Z" }, + { url = "https://files.pythonhosted.org/packages/55/2a/35860f33229075bce803a5593d046d8b489d7ba2fc85701e714fc1aaf898/msgpack-1.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c", size = 407336, upload-time = "2025-06-13T06:52:09.047Z" }, + { url = 
"https://files.pythonhosted.org/packages/8c/16/69ed8f3ada150bf92745fb4921bd621fd2cdf5a42e25eb50bcc57a5328f0/msgpack-1.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b", size = 409485, upload-time = "2025-06-13T06:52:10.382Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b6/0c398039e4c6d0b2e37c61d7e0e9d13439f91f780686deb8ee64ecf1ae71/msgpack-1.1.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef", size = 412182, upload-time = "2025-06-13T06:52:11.644Z" }, + { url = "https://files.pythonhosted.org/packages/b8/d0/0cf4a6ecb9bc960d624c93effaeaae75cbf00b3bc4a54f35c8507273cda1/msgpack-1.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a", size = 419883, upload-time = "2025-06-13T06:52:12.806Z" }, + { url = "https://files.pythonhosted.org/packages/62/83/9697c211720fa71a2dfb632cad6196a8af3abea56eece220fde4674dc44b/msgpack-1.1.1-cp312-cp312-win32.whl", hash = "sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c", size = 65406, upload-time = "2025-06-13T06:52:14.271Z" }, + { url = "https://files.pythonhosted.org/packages/c0/23/0abb886e80eab08f5e8c485d6f13924028602829f63b8f5fa25a06636628/msgpack-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4", size = 72558, upload-time = "2025-06-13T06:52:15.252Z" }, +] + [[package]] name = "msrest" version = "0.7.1" @@ -4892,6 +5039,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863, upload-time = "2024-01-23T06:32:58.246Z" }, ] +[[package]] +name = "python-engineio" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simple-websocket" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/0b/67295279b66835f9fa7a491650efcd78b20321c127036eef62c11a31e028/python_engineio-4.12.2.tar.gz", hash = "sha256:e7e712ffe1be1f6a05ee5f951e72d434854a32fcfc7f6e4d9d3cae24ec70defa", size = 91677, upload-time = "2025-06-04T19:22:18.789Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, +] + [[package]] name = "python-http-client" version = "3.3.7" @@ -4948,6 +5107,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "python-socketio" +version = "5.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bidict" }, + { name = "python-engineio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/1a/396d50ccf06ee539fa758ce5623b59a9cb27637fc4b2dc07ed08bf495e77/python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029", size = 121125, upload-time = "2025-04-12T15:46:59.933Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/32/b4fb8585d1be0f68bde7e110dffbcf354915f77ad8c778563f0ad9655c02/python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf", size = 77800, upload-time = "2025-04-12T15:46:58.412Z" }, +] + +[package.optional-dependencies] +client = [ + { name = "requests" }, + { name = "websocket-client" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -5005,6 +5183,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" }, ] +[[package]] +name = "pyzmq" +version = "27.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/5d/305323ba86b284e6fcb0d842d6adaa2999035f70f8c38a9b6d21ad28c3d4/pyzmq-27.1.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:226b091818d461a3bef763805e75685e478ac17e9008f49fce2d3e52b3d58b86", size = 1333328, upload-time = "2025-09-08T23:07:45.946Z" }, + { url = "https://files.pythonhosted.org/packages/bd/a0/fc7e78a23748ad5443ac3275943457e8452da67fda347e05260261108cbc/pyzmq-27.1.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:0790a0161c281ca9723f804871b4027f2e8b5a528d357c8952d08cd1a9c15581", size = 908803, upload-time = "2025-09-08T23:07:47.551Z" }, + { url = "https://files.pythonhosted.org/packages/7e/22/37d15eb05f3bdfa4abea6f6d96eb3bb58585fbd3e4e0ded4e743bc650c97/pyzmq-27.1.0-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c895a6f35476b0c3a54e3eb6ccf41bf3018de937016e6e18748317f25d4e925f", size = 668836, upload-time = "2025-09-08T23:07:49.436Z" }, + { url = "https://files.pythonhosted.org/packages/b1/c4/2a6fe5111a01005fc7af3878259ce17684fabb8852815eda6225620f3c59/pyzmq-27.1.0-cp311-cp311-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bbf8d3630bf96550b3be8e1fc0fea5cbdc8d5466c1192887bd94869da17a63e", size = 857038, upload-time = "2025-09-08T23:07:51.234Z" }, + { url = "https://files.pythonhosted.org/packages/cb/eb/bfdcb41d0db9cd233d6fb22dc131583774135505ada800ebf14dfb0a7c40/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:15c8bd0fe0dabf808e2d7a681398c4e5ded70a551ab47482067a572c054c8e2e", size = 1657531, upload-time = "2025-09-08T23:07:52.795Z" }, + { url = "https://files.pythonhosted.org/packages/ab/21/e3180ca269ed4a0de5c34417dfe71a8ae80421198be83ee619a8a485b0c7/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bafcb3dd171b4ae9f19ee6380dfc71ce0390fefaf26b504c0e5f628d7c8c54f2", size = 2034786, upload-time = "2025-09-08T23:07:55.047Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b1/5e21d0b517434b7f33588ff76c177c5a167858cc38ef740608898cd329f2/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e829529fcaa09937189178115c49c504e69289abd39967cd8a4c215761373394", size = 1894220, upload-time = "2025-09-08T23:07:57.172Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/f2/44913a6ff6941905efc24a1acf3d3cb6146b636c546c7406c38c49c403d4/pyzmq-27.1.0-cp311-cp311-win32.whl", hash = "sha256:6df079c47d5902af6db298ec92151db82ecb557af663098b92f2508c398bb54f", size = 567155, upload-time = "2025-09-08T23:07:59.05Z" }, + { url = "https://files.pythonhosted.org/packages/23/6d/d8d92a0eb270a925c9b4dd039c0b4dc10abc2fcbc48331788824ef113935/pyzmq-27.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:190cbf120fbc0fc4957b56866830def56628934a9d112aec0e2507aa6a032b97", size = 633428, upload-time = "2025-09-08T23:08:00.663Z" }, + { url = "https://files.pythonhosted.org/packages/ae/14/01afebc96c5abbbd713ecfc7469cfb1bc801c819a74ed5c9fad9a48801cb/pyzmq-27.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:eca6b47df11a132d1745eb3b5b5e557a7dae2c303277aa0e69c6ba91b8736e07", size = 559497, upload-time = "2025-09-08T23:08:02.15Z" }, + { url = "https://files.pythonhosted.org/packages/92/e7/038aab64a946d535901103da16b953c8c9cc9c961dadcbf3609ed6428d23/pyzmq-27.1.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:452631b640340c928fa343801b0d07eb0c3789a5ffa843f6e1a9cee0ba4eb4fc", size = 1306279, upload-time = "2025-09-08T23:08:03.807Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5e/c3c49fdd0f535ef45eefcc16934648e9e59dace4a37ee88fc53f6cd8e641/pyzmq-27.1.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1c179799b118e554b66da67d88ed66cd37a169f1f23b5d9f0a231b4e8d44a113", size = 895645, upload-time = "2025-09-08T23:08:05.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/b0b2504cb4e903a74dcf1ebae157f9e20ebb6ea76095f6cfffea28c42ecd/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3837439b7f99e60312f0c926a6ad437b067356dc2bc2ec96eb395fd0fe804233", size = 652574, upload-time = "2025-09-08T23:08:06.828Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9b/c108cdb55560eaf253f0cbdb61b29971e9fb34d9c3499b0e96e4e60ed8a5/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43ad9a73e3da1fab5b0e7e13402f0b2fb934ae1c876c51d0afff0e7c052eca31", size = 840995, upload-time = "2025-09-08T23:08:08.396Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bb/b79798ca177b9eb0825b4c9998c6af8cd2a7f15a6a1a4272c1d1a21d382f/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0de3028d69d4cdc475bfe47a6128eb38d8bc0e8f4d69646adfbcd840facbac28", size = 1642070, upload-time = "2025-09-08T23:08:09.989Z" }, + { url = "https://files.pythonhosted.org/packages/9c/80/2df2e7977c4ede24c79ae39dcef3899bfc5f34d1ca7a5b24f182c9b7a9ca/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:cf44a7763aea9298c0aa7dbf859f87ed7012de8bda0f3977b6fb1d96745df856", size = 2021121, upload-time = "2025-09-08T23:08:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/46/bd/2d45ad24f5f5ae7e8d01525eb76786fa7557136555cac7d929880519e33a/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f30f395a9e6fbca195400ce833c731e7b64c3919aa481af4d88c3759e0cb7496", size = 1878550, upload-time = "2025-09-08T23:08:13.513Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2f/104c0a3c778d7c2ab8190e9db4f62f0b6957b53c9d87db77c284b69f33ea/pyzmq-27.1.0-cp312-abi3-win32.whl", hash = "sha256:250e5436a4ba13885494412b3da5d518cd0d3a278a1ae640e113c073a5f88edd", size = 559184, upload-time = "2025-09-08T23:08:15.163Z" }, + { url = 
"https://files.pythonhosted.org/packages/fc/7f/a21b20d577e4100c6a41795842028235998a643b1ad406a6d4163ea8f53e/pyzmq-27.1.0-cp312-abi3-win_amd64.whl", hash = "sha256:9ce490cf1d2ca2ad84733aa1d69ce6855372cb5ce9223802450c9b2a7cba0ccf", size = 619480, upload-time = "2025-09-08T23:08:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/78/c2/c012beae5f76b72f007a9e91ee9401cb88c51d0f83c6257a03e785c81cc2/pyzmq-27.1.0-cp312-abi3-win_arm64.whl", hash = "sha256:75a2f36223f0d535a0c919e23615fc85a1e23b71f40c7eb43d7b1dedb4d8f15f", size = 552993, upload-time = "2025-09-08T23:08:18.926Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c6/c4dcdecdbaa70969ee1fdced6d7b8f60cfabe64d25361f27ac4665a70620/pyzmq-27.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:18770c8d3563715387139060d37859c02ce40718d1faf299abddcdcc6a649066", size = 836265, upload-time = "2025-09-08T23:09:49.376Z" }, + { url = "https://files.pythonhosted.org/packages/3e/79/f38c92eeaeb03a2ccc2ba9866f0439593bb08c5e3b714ac1d553e5c96e25/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ac25465d42f92e990f8d8b0546b01c391ad431c3bf447683fdc40565941d0604", size = 800208, upload-time = "2025-09-08T23:09:51.073Z" }, + { url = "https://files.pythonhosted.org/packages/49/0e/3f0d0d335c6b3abb9b7b723776d0b21fa7f3a6c819a0db6097059aada160/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53b40f8ae006f2734ee7608d59ed661419f087521edbfc2149c3932e9c14808c", size = 567747, upload-time = "2025-09-08T23:09:52.698Z" }, + { url = "https://files.pythonhosted.org/packages/a1/cf/f2b3784d536250ffd4be70e049f3b60981235d70c6e8ce7e3ef21e1adb25/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f605d884e7c8be8fe1aa94e0a783bf3f591b84c24e4bc4f3e7564c82ac25e271", size = 747371, upload-time = "2025-09-08T23:09:54.563Z" }, + { url = "https://files.pythonhosted.org/packages/01/1b/5dbe84eefc86f48473947e2f41711aded97eecef1231f4558f1f02713c12/pyzmq-27.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c9f7f6e13dff2e44a6afeaf2cf54cee5929ad64afaf4d40b50f93c58fc687355", size = 544862, upload-time = "2025-09-08T23:09:56.509Z" }, +] + [[package]] name = "qdrant-client" version = "1.9.0" @@ -5453,6 +5667,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "simple-websocket" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/d4/bfa032f961103eba93de583b161f0e6a5b63cebb8f2c7d0c6e6efe1e3d2e/simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4", size = 17300, upload-time = "2024-10-10T22:39:31.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/59/0782e51887ac6b07ffd1570e0364cf901ebc36345fea669969d2084baebb/simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c", size = 13842, upload-time = "2024-10-10T22:39:29.645Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -5772,27 +5998,27 @@ wheels = [ [[package]] name = "tokenizers" -version = "0.21.2" +version = "0.22.0" source = { 
registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/2d/b0fce2b8201635f60e8c95990080f58461cc9ca3d5026de2e900f38a7f21/tokenizers-0.21.2.tar.gz", hash = "sha256:fdc7cffde3e2113ba0e6cc7318c40e3438a4d74bbc62bf04bcc63bdfb082ac77", size = 351545, upload-time = "2025-06-24T10:24:52.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/b4/c1ce3699e81977da2ace8b16d2badfd42b060e7d33d75c4ccdbf9dc920fa/tokenizers-0.22.0.tar.gz", hash = "sha256:2e33b98525be8453f355927f3cab312c36cd3e44f4d7e9e97da2fa94d0a49dcb", size = 362771, upload-time = "2025-08-29T10:25:33.914Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/cc/2936e2d45ceb130a21d929743f1e9897514691bec123203e10837972296f/tokenizers-0.21.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:342b5dfb75009f2255ab8dec0041287260fed5ce00c323eb6bab639066fef8ec", size = 2875206, upload-time = "2025-06-24T10:24:42.755Z" }, - { url = "https://files.pythonhosted.org/packages/6c/e6/33f41f2cc7861faeba8988e7a77601407bf1d9d28fc79c5903f8f77df587/tokenizers-0.21.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:126df3205d6f3a93fea80c7a8a266a78c1bd8dd2fe043386bafdd7736a23e45f", size = 2732655, upload-time = "2025-06-24T10:24:41.56Z" }, - { url = "https://files.pythonhosted.org/packages/33/2b/1791eb329c07122a75b01035b1a3aa22ad139f3ce0ece1b059b506d9d9de/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a32cd81be21168bd0d6a0f0962d60177c447a1aa1b1e48fa6ec9fc728ee0b12", size = 3019202, upload-time = "2025-06-24T10:24:31.791Z" }, - { url = "https://files.pythonhosted.org/packages/05/15/fd2d8104faa9f86ac68748e6f7ece0b5eb7983c7efc3a2c197cb98c99030/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8bd8999538c405133c2ab999b83b17c08b7fc1b48c1ada2469964605a709ef91", size = 2934539, upload-time = "2025-06-24T10:24:34.567Z" }, - { url = "https://files.pythonhosted.org/packages/a5/2e/53e8fd053e1f3ffbe579ca5f9546f35ac67cf0039ed357ad7ec57f5f5af0/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e9944e61239b083a41cf8fc42802f855e1dca0f499196df37a8ce219abac6eb", size = 3248665, upload-time = "2025-06-24T10:24:39.024Z" }, - { url = "https://files.pythonhosted.org/packages/00/15/79713359f4037aa8f4d1f06ffca35312ac83629da062670e8830917e2153/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:514cd43045c5d546f01142ff9c79a96ea69e4b5cda09e3027708cb2e6d5762ab", size = 3451305, upload-time = "2025-06-24T10:24:36.133Z" }, - { url = "https://files.pythonhosted.org/packages/38/5f/959f3a8756fc9396aeb704292777b84f02a5c6f25c3fc3ba7530db5feb2c/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b1b9405822527ec1e0f7d8d2fdb287a5730c3a6518189c968254a8441b21faae", size = 3214757, upload-time = "2025-06-24T10:24:37.784Z" }, - { url = "https://files.pythonhosted.org/packages/c5/74/f41a432a0733f61f3d21b288de6dfa78f7acff309c6f0f323b2833e9189f/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fed9a4d51c395103ad24f8e7eb976811c57fbec2af9f133df471afcd922e5020", size = 3121887, upload-time = "2025-06-24T10:24:40.293Z" }, - { url = "https://files.pythonhosted.org/packages/3c/6a/bc220a11a17e5d07b0dfb3b5c628621d4dcc084bccd27cfaead659963016/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = 
"sha256:2c41862df3d873665ec78b6be36fcc30a26e3d4902e9dd8608ed61d49a48bc19", size = 9091965, upload-time = "2025-06-24T10:24:44.431Z" }, - { url = "https://files.pythonhosted.org/packages/6c/bd/ac386d79c4ef20dc6f39c4706640c24823dca7ebb6f703bfe6b5f0292d88/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed21dc7e624e4220e21758b2e62893be7101453525e3d23264081c9ef9a6d00d", size = 9053372, upload-time = "2025-06-24T10:24:46.455Z" }, - { url = "https://files.pythonhosted.org/packages/63/7b/5440bf203b2a5358f074408f7f9c42884849cd9972879e10ee6b7a8c3b3d/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:0e73770507e65a0e0e2a1affd6b03c36e3bc4377bd10c9ccf51a82c77c0fe365", size = 9298632, upload-time = "2025-06-24T10:24:48.446Z" }, - { url = "https://files.pythonhosted.org/packages/a4/d2/faa1acac3f96a7427866e94ed4289949b2524f0c1878512516567d80563c/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:106746e8aa9014a12109e58d540ad5465b4c183768ea96c03cbc24c44d329958", size = 9470074, upload-time = "2025-06-24T10:24:50.378Z" }, - { url = "https://files.pythonhosted.org/packages/d8/a5/896e1ef0707212745ae9f37e84c7d50269411aef2e9ccd0de63623feecdf/tokenizers-0.21.2-cp39-abi3-win32.whl", hash = "sha256:cabda5a6d15d620b6dfe711e1af52205266d05b379ea85a8a301b3593c60e962", size = 2330115, upload-time = "2025-06-24T10:24:55.069Z" }, - { url = "https://files.pythonhosted.org/packages/13/c3/cc2755ee10be859c4338c962a35b9a663788c0c0b50c0bdd8078fb6870cf/tokenizers-0.21.2-cp39-abi3-win_amd64.whl", hash = "sha256:58747bb898acdb1007f37a7bbe614346e98dc28708ffb66a3fd50ce169ac6c98", size = 2509918, upload-time = "2025-06-24T10:24:53.71Z" }, + { url = "https://files.pythonhosted.org/packages/6d/b1/18c13648edabbe66baa85fe266a478a7931ddc0cd1ba618802eb7b8d9865/tokenizers-0.22.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:eaa9620122a3fb99b943f864af95ed14c8dfc0f47afa3b404ac8c16b3f2bb484", size = 3081954, upload-time = "2025-08-29T10:25:24.993Z" }, + { url = "https://files.pythonhosted.org/packages/c2/02/c3c454b641bd7c4f79e4464accfae9e7dfc913a777d2e561e168ae060362/tokenizers-0.22.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:71784b9ab5bf0ff3075bceeb198149d2c5e068549c0d18fe32d06ba0deb63f79", size = 2945644, upload-time = "2025-08-29T10:25:23.405Z" }, + { url = "https://files.pythonhosted.org/packages/55/02/d10185ba2fd8c2d111e124c9d92de398aee0264b35ce433f79fb8472f5d0/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec5b71f668a8076802b0241a42387d48289f25435b86b769ae1837cad4172a17", size = 3254764, upload-time = "2025-08-29T10:25:12.445Z" }, + { url = "https://files.pythonhosted.org/packages/13/89/17514bd7ef4bf5bfff58e2b131cec0f8d5cea2b1c8ffe1050a2c8de88dbb/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ea8562fa7498850d02a16178105b58803ea825b50dc9094d60549a7ed63654bb", size = 3161654, upload-time = "2025-08-29T10:25:15.493Z" }, + { url = "https://files.pythonhosted.org/packages/5a/d8/bac9f3a7ef6dcceec206e3857c3b61bb16c6b702ed7ae49585f5bd85c0ef/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4136e1558a9ef2e2f1de1555dcd573e1cbc4a320c1a06c4107a3d46dc8ac6e4b", size = 3511484, upload-time = "2025-08-29T10:25:20.477Z" }, + { url = "https://files.pythonhosted.org/packages/aa/27/9c9800eb6763683010a4851db4d1802d8cab9cec114c17056eccb4d4a6e0/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:cdf5954de3962a5fd9781dc12048d24a1a6f1f5df038c6e95db328cd22964206", size = 3712829, upload-time = "2025-08-29T10:25:17.154Z" }, + { url = "https://files.pythonhosted.org/packages/10/e3/b1726dbc1f03f757260fa21752e1921445b5bc350389a8314dd3338836db/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8337ca75d0731fc4860e6204cc24bb36a67d9736142aa06ed320943b50b1e7ed", size = 3408934, upload-time = "2025-08-29T10:25:18.76Z" }, + { url = "https://files.pythonhosted.org/packages/d4/61/aeab3402c26874b74bb67a7f2c4b569dde29b51032c5384db592e7b216f4/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a89264e26f63c449d8cded9061adea7b5de53ba2346fc7e87311f7e4117c1cc8", size = 3345585, upload-time = "2025-08-29T10:25:22.08Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d3/498b4a8a8764cce0900af1add0f176ff24f475d4413d55b760b8cdf00893/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:790bad50a1b59d4c21592f9c3cf5e5cf9c3c7ce7e1a23a739f13e01fb1be377a", size = 9322986, upload-time = "2025-08-29T10:25:26.607Z" }, + { url = "https://files.pythonhosted.org/packages/a2/62/92378eb1c2c565837ca3cb5f9569860d132ab9d195d7950c1ea2681dffd0/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:76cf6757c73a10ef10bf06fa937c0ec7393d90432f543f49adc8cab3fb6f26cb", size = 9276630, upload-time = "2025-08-29T10:25:28.349Z" }, + { url = "https://files.pythonhosted.org/packages/eb/f0/342d80457aa1cda7654327460f69db0d69405af1e4c453f4dc6ca7c4a76e/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:1626cb186e143720c62c6c6b5371e62bbc10af60481388c0da89bc903f37ea0c", size = 9547175, upload-time = "2025-08-29T10:25:29.989Z" }, + { url = "https://files.pythonhosted.org/packages/14/84/8aa9b4adfc4fbd09381e20a5bc6aa27040c9c09caa89988c01544e008d18/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:da589a61cbfea18ae267723d6b029b84598dc8ca78db9951d8f5beff72d8507c", size = 9692735, upload-time = "2025-08-29T10:25:32.089Z" }, + { url = "https://files.pythonhosted.org/packages/bf/24/83ee2b1dc76bfe05c3142e7d0ccdfe69f0ad2f1ebf6c726cea7f0874c0d0/tokenizers-0.22.0-cp39-abi3-win32.whl", hash = "sha256:dbf9d6851bddae3e046fedfb166f47743c1c7bd11c640f0691dd35ef0bcad3be", size = 2471915, upload-time = "2025-08-29T10:25:36.411Z" }, + { url = "https://files.pythonhosted.org/packages/d1/9b/0e0bf82214ee20231845b127aa4a8015936ad5a46779f30865d10e404167/tokenizers-0.22.0-cp39-abi3-win_amd64.whl", hash = "sha256:c78174859eeaee96021f248a56c801e36bfb6bd5b067f2e95aa82445ca324f00", size = 2680494, upload-time = "2025-08-29T10:25:35.14Z" }, ] [[package]] @@ -5860,7 +6086,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.53.3" +version = "4.56.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -5874,9 +6100,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/5c/49182918b58eaa0b4c954fd0e37c79fc299e5643e69d70089d0b0eb0cd9b/transformers-4.53.3.tar.gz", hash = "sha256:b2eda1a261de79b78b97f7888fe2005fc0c3fabf5dad33d52cc02983f9f675d8", size = 9197478, upload-time = "2025-07-22T07:30:51.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/89/21/dc88ef3da1e49af07ed69386a11047a31dcf1aaf4ded3bc4b173fbf94116/transformers-4.56.1.tar.gz", hash = "sha256:0d88b1089a563996fc5f2c34502f10516cad3ea1aa89f179f522b54c8311fe74", size = 9855473, upload-time = "2025-09-04T20:47:13.14Z" } 
wheels = [ - { url = "https://files.pythonhosted.org/packages/41/b1/d7520cc5cb69c825599042eb3a7c986fa9baa8a8d2dea9acd78e152c81e2/transformers-4.53.3-py3-none-any.whl", hash = "sha256:5aba81c92095806b6baf12df35d756cf23b66c356975fb2a7fa9e536138d7c75", size = 10826382, upload-time = "2025-07-22T07:30:48.458Z" }, + { url = "https://files.pythonhosted.org/packages/71/7c/283c3dd35e00e22a7803a0b2a65251347b745474a82399be058bde1c9f15/transformers-4.56.1-py3-none-any.whl", hash = "sha256:1697af6addfb6ddbce9618b763f4b52d5a756f6da4899ffd1b4febf58b779248", size = 11608197, upload-time = "2025-09-04T20:47:04.895Z" }, ] [[package]] @@ -6860,6 +7086,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" }, ] +[[package]] +name = "wsproto" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425, upload-time = "2022-08-23T19:58:21.447Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" }, +] + [[package]] name = "xinference-client" version = "1.2.2" diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py index 63d0cbaf3a..1ec95deb09 100644 --- a/dev/pytest/pytest_config_tests.py +++ b/dev/pytest/pytest_config_tests.py @@ -1,6 +1,7 @@ +from pathlib import Path + import yaml # type: ignore from dotenv import dotenv_values -from pathlib import Path BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "APP_MAX_EXECUTION_TIME", @@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f: def test_yaml_config(): # python set == operator is used to compare two sets - DIFF_API_WITH_DOCKER = ( - API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF - ) + DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF if DIFF_API_WITH_DOCKER: - print( - f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}" - ) + print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}") raise Exception("API and Docker config sets are different") DIFF_API_WITH_DOCKER_COMPOSE = ( - API_CONFIG_SET - - DOCKER_COMPOSE_CONFIG_SET - - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF + API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF ) if DIFF_API_WITH_DOCKER_COMPOSE: - print( - f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}" - ) + print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}") raise Exception("API and Docker Compose config sets are different") print("All tests passed!") diff --git a/docker/.env.example b/docker/.env.example index e50153a529..c39a97970b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -843,6 +843,7 @@ INVITE_EXPIRY_HOURS=72 # Reset password token valid time (minutes), 
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 5924877c7d..761c5868fd 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -372,6 +372,7 @@ x-shared-env: &shared-api-worker-env
   INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
   INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
   RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
+  EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: ${EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES:-5}
   CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: ${CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES:-5}
   OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5}
   CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
diff --git a/scripts/stress-test/README.md b/scripts/stress-test/README.md
new file mode 100644
index 0000000000..15f21cd532
--- /dev/null
+++ b/scripts/stress-test/README.md
@@ -0,0 +1,521 @@
+# Dify Stress Test Suite
+
+A high-performance stress test suite for Dify workflow execution using **Locust** - optimized for measuring Server-Sent Events (SSE) streaming performance.
+
+## Key Metrics Tracked
+
+The stress test focuses on four critical SSE performance indicators:
+
+1. **Active SSE Connections** - Real-time count of open SSE connections
+1. **New Connection Rate** - Connections per second (conn/sec)
+1. **Time to First Event (TTFE)** - Latency until first SSE event arrives
+1. **Event Throughput** - Events per second (events/sec)
+
+## Features
+
+- **True SSE Support**: Properly handles Server-Sent Events streaming without premature connection closure
+- **Real-time Metrics**: Live reporting every 5 seconds during tests
+- **Comprehensive Tracking**:
+  - Active connection monitoring
+  - Connection establishment rate
+  - Event processing throughput
+  - TTFE distribution analysis
+- **Multiple Interfaces**:
+  - Web UI for real-time monitoring (http://localhost:8089)
+  - Headless mode with periodic console updates
+- **Detailed Reports**: Final statistics with overall rates and averages
+- **Easy Configuration**: Uses existing API key configuration from setup
+
+## What Gets Measured
+
+The stress test focuses on SSE streaming performance with these key metrics:
+
+### Primary Endpoint: `/v1/workflows/run`
+
+The suite exercises a single endpoint with comprehensive SSE metrics tracking:
+
+- **Request Type**: POST request to workflow execution API
+- **Response Type**: Server-Sent Events (SSE) stream
+- **Payload**: Random questions from a configurable pool
+- **Concurrency**: Configurable from 1 to 1000+ simultaneous users
+
+### Key Performance Metrics
+
+#### 1. **Active Connections**
+
+- **What it measures**: Number of concurrent SSE connections open at any moment
+- **Why it matters**: Shows system's ability to handle parallel streams
+- **Good values**: Should remain stable under load without drops
+
+#### 2. **Connection Rate (conn/sec)**
+
+- **What it measures**: How fast new SSE connections are established
+- **Why it matters**: Indicates system's ability to handle connection spikes
+- **Good values**:
+  - Light load: 5-10 conn/sec
+  - Medium load: 20-50 conn/sec
+  - Heavy load: 100+ conn/sec
+#### 3. **Time to First Event (TTFE)**
+
+- **What it measures**: Latency from request sent to first SSE event received
+- **Why it matters**: Critical for user experience - faster TTFE = better perceived performance
+- **Good values**:
+  - Excellent: < 50ms
+  - Good: 50-100ms
+  - Acceptable: 100-500ms
+  - Poor: > 500ms
+
+#### 4. **Event Throughput (events/sec)**
+
+- **What it measures**: Rate of SSE events being delivered across all connections
+- **Why it matters**: Shows actual data delivery performance
+- **Expected values**: Depends on workflow complexity and number of connections
+  - Single connection: 10-20 events/sec
+  - 10 connections: 50-100 events/sec
+  - 100 connections: 200-500 events/sec
+
+#### 5. **Request/Response Times**
+
+- **P50 (Median)**: 50% of requests complete within this time
+- **P95**: 95% of requests complete within this time
+- **P99**: 99% of requests complete within this time
+- **Min/Max**: Best and worst case response times
+
+## Prerequisites
+
+1. **Dependencies are automatically installed** when running setup:
+
+   - Locust (load testing framework)
+   - sseclient-py (SSE client library)
+
+1. **Complete Dify setup**:
+
+   ```bash
+   # Run the complete setup
+   python scripts/stress-test/setup_all.py
+   ```
+
+1. **Ensure services are running**:
+
+   **IMPORTANT**: For accurate stress testing, run the API server with Gunicorn in production mode:
+
+   ```bash
+   # Run from the api directory
+   cd api
+   uv run gunicorn \
+     --bind 0.0.0.0:5001 \
+     --workers 4 \
+     --worker-class gevent \
+     --timeout 120 \
+     --keep-alive 5 \
+     --log-level info \
+     --access-logfile - \
+     --error-logfile - \
+     app:app
+   ```
+
+   **Configuration options explained**:
+
+   - `--workers 4`: Number of worker processes (adjust based on CPU cores)
+   - `--worker-class gevent`: Async worker for handling concurrent connections
+   - `--timeout 120`: Worker timeout for long-running requests
+   - `--keep-alive 5`: Keep connections alive for SSE streaming
+
+   **NOT RECOMMENDED for stress testing**:
+
+   ```bash
+   # Debug mode - DO NOT use for stress testing (slow performance)
+   ./dev/start-api  # This runs Flask in debug mode with single-threaded execution
+   ```
+
+   **Also start the Mock OpenAI server**:
+
+   ```bash
+   python scripts/stress-test/setup/mock_openai_server.py
+   ```
+
+## Running the Stress Test
+
+```bash
+# Run with default configuration (headless mode)
+./scripts/stress-test/run_locust_stress_test.sh
+
+# Or run directly with uv
+uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py --host http://localhost:5001
+
+# Run with Web UI (access at http://localhost:8089)
+uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py --host http://localhost:5001 --web-port 8089
+```
+
+The script will:
+
+1. Validate that all required services are running
+1. Check API token availability
+1. Execute the Locust stress test with SSE support
+1. Generate comprehensive reports in the `reports/` directory
+
+## Configuration
+
+The stress test configuration is in `locust.conf`:
+
+```ini
+users = 10        # Number of concurrent users
+spawn-rate = 2    # Users spawned per second
+run-time = 1m     # Test duration (30s, 5m, 1h)
+headless = true   # Run without web UI
+```
+
+### Custom Question Sets
+
+Modify the questions list in `sse_benchmark.py`:
+
+```python
+self.questions = [
+    "Your custom question 1",
+    "Your custom question 2",
+    # Add more questions...
+]
+```
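+### How the Locustfile Streams SSE
+
+`sse_benchmark.py` itself is not reproduced here; the sketch below only illustrates the general shape an SSE-streaming Locust user can take. The class name, wait times, and the Bearer token placeholder are assumptions for illustration, not the suite's actual code:
+
+```python
+# Illustrative sketch of an SSE-streaming Locust user (assumed names, not the suite's code).
+import random
+import time
+
+from locust import HttpUser, between, task
+
+QUESTIONS = ["Your custom question 1", "Your custom question 2"]
+
+
+class SSEWorkflowUser(HttpUser):
+    wait_time = between(0.5, 2)
+
+    @task
+    def run_workflow(self) -> None:
+        payload = {
+            "inputs": {"question": random.choice(QUESTIONS)},
+            "response_mode": "streaming",
+            "user": "locust-user",
+        }
+        headers = {"Authorization": "Bearer app-..."}  # assumption: key from the setup state
+        start = time.perf_counter()
+        ttfe_ms = None
+        events = 0
+        # stream=True keeps the connection open while SSE events arrive
+        with self.client.post(
+            "/v1/workflows/run", json=payload, headers=headers, stream=True
+        ) as response:
+            for raw in response.iter_lines():
+                if raw and raw.startswith(b"data:"):
+                    if ttfe_ms is None:
+                        ttfe_ms = (time.perf_counter() - start) * 1000
+                    events += 1
+```
+
+The real locustfile additionally feeds these per-stream numbers into the live metrics box described under "How to Read the Results" below.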
+## Understanding the Results
+
+### Report Structure
+
+After running the stress test, you'll find these files in the `reports/` directory:
+
+- `locust_summary_YYYYMMDD_HHMMSS.txt` - Complete console output with metrics
+- `locust_report_YYYYMMDD_HHMMSS.html` - Interactive HTML report with charts
+- `locust_YYYYMMDD_HHMMSS_stats.csv` - CSV with detailed statistics
+- `locust_YYYYMMDD_HHMMSS_stats_history.csv` - Time-series data
+
+### Key Metrics
+
+**Requests Per Second (RPS)**:
+
+- **Excellent**: > 50 RPS
+- **Good**: 20-50 RPS
+- **Acceptable**: 10-20 RPS
+- **Needs Improvement**: < 10 RPS
+
+**Response Time Percentiles**:
+
+- **P50 (Median)**: 50% of requests complete within this time
+- **P95**: 95% of requests complete within this time
+- **P99**: 99% of requests complete within this time
+
+**Success Rate**:
+
+- Should be > 99% for production readiness
+- Lower rates indicate errors or timeouts
+
+### Example Output
+
+```text
+============================================================
+DIFY SSE STRESS TEST
+============================================================
+
+[2025-09-12 15:45:44,468] Starting test run with 10 users at 2 users/sec
+
+============================================================
+SSE Metrics | Active: 8 | Total Conn: 142 | Events: 2841
+Rates: 2.4 conn/s | 47.3 events/s | TTFE: 43ms
+============================================================
+
+Type     Name                          # reqs     # fails |    Avg    Min    Max    Med |   req/s failures/s
+---------|------------------------------|--------|--------|--------|--------|--------|--------|--------|-----------
+POST     /v1/workflows/run                 142    0(0.00%) |     41     18    192     38 |    2.37       0.00
+---------|------------------------------|--------|--------|--------|--------|--------|--------|--------|-----------
+         Aggregated                        142    0(0.00%) |     41     18    192     38 |    2.37       0.00
+
+============================================================
+FINAL RESULTS
+============================================================
+Total Connections: 142
+Total Events: 2841
+Average TTFE: 43 ms
+============================================================
+```
+
+### How to Read the Results
+
+**Live SSE Metrics Box (updated live during the run):**
+
+```text
+SSE Metrics | Active: 8 | Total Conn: 142 | Events: 2841
+Rates: 2.4 conn/s | 47.3 events/s | TTFE: 43ms
+```
+
+- **Active**: Current number of open SSE connections
+- **Total Conn**: Cumulative connections established
+- **Events**: Total SSE events received
+- **conn/s**: Connection establishment rate
+- **events/s**: Event delivery rate
+- **TTFE**: Average time to first event
+
+**Standard Locust Table:**
+
+```text
+Type     Name                          # reqs     # fails |    Avg    Min    Max    Med |   req/s
+POST     /v1/workflows/run                 142    0(0.00%) |     41     18    192     38 |    2.37
+```
+
+- **Type**: Always POST for our SSE requests
+- **Name**: The API endpoint being tested
+- **# reqs**: Total requests made
+- **# fails**: Failed requests (should be 0)
+- **Avg/Min/Max/Med**: Response-time statistics (ms)
+- **req/s**: Request throughput
+
+**Performance Targets:**
+
+✅ **Good Performance**:
+
+- Zero failures (0.00%)
+- TTFE < 100ms
+- Stable active connections
+- Consistent event throughput
+
+⚠️ **Warning Signs**:
+
+- Failures > 1%
+- TTFE > 500ms
+- Dropping active connections
+- Declining event rate over time
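+### Spot-Checking TTFE
+
+Before scaling up, it can help to sanity-check TTFE with a single hand-rolled request and compare it with what Locust reports. The probe below is a sketch, not part of the suite: it assumes `httpx` is available and that `API_TOKEN` holds the app API key created by `setup/create_api_key.py`; the payload mirrors the suite's `/v1/workflows/run` call.
+
+```python
+# Minimal TTFE probe (illustrative sketch, not part of the suite).
+import time
+
+import httpx
+
+API_TOKEN = "app-..."  # assumption: replace with the key from the setup state
+
+
+def measure_ttfe(host: str = "http://localhost:5001") -> float:
+    """Return the time to first SSE event, in milliseconds."""
+    payload = {
+        "inputs": {"question": "What is Dify?"},
+        "response_mode": "streaming",
+        "user": "ttfe-probe",
+    }
+    headers = {"Authorization": f"Bearer {API_TOKEN}"}
+    start = time.perf_counter()
+    with httpx.stream(
+        "POST", f"{host}/v1/workflows/run", json=payload, headers=headers, timeout=120
+    ) as response:
+        response.raise_for_status()
+        for line in response.iter_lines():
+            if line.startswith("data:"):  # first SSE event ends the wait
+                return (time.perf_counter() - start) * 1000
+    raise RuntimeError("stream closed before any SSE event arrived")
+
+
+if __name__ == "__main__":
+    print(f"TTFE: {measure_ttfe():.0f} ms")
+```
+
+A probe result far below what the suite reports under load usually points at generator-side saturation rather than server latency.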
+## Test Scenarios
+
+The scenario sizes below are expressed with `locust.conf` keys; `iterations` maps to Locust's `--iterations` option in recent releases, and on older versions `run-time` is the closest substitute:
+
+### Light Load
+
+```ini
+users = 10
+iterations = 100
+```
+
+### Normal Load
+
+```ini
+users = 100
+iterations = 1000
+```
+
+### Heavy Load
+
+```ini
+users = 500
+iterations = 5000
+```
+
+### Stress Test
+
+```ini
+users = 1000
+iterations = 10000
+```
+
+## Performance Tuning
+
+### API Server Optimization
+
+**Gunicorn Tuning for Different Load Levels**:
+
+```bash
+# Light load (10-50 concurrent users)
+uv run gunicorn --bind 0.0.0.0:5001 --workers 2 --worker-class gevent app:app
+
+# Medium load (50-200 concurrent users)
+uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent --worker-connections 1000 app:app
+
+# Heavy load (200-1000 concurrent users)
+uv run gunicorn --bind 0.0.0.0:5001 --workers 8 --worker-class gevent --worker-connections 2000 --max-requests 1000 app:app
+```
+
+**Worker calculation formula**:
+
+- Workers = (2 × CPU cores) + 1
+- For SSE/WebSocket: Use gevent worker class
+- For CPU-bound tasks: Use sync workers
+
+### Database Optimization
+
+**PostgreSQL Connection Pool Tuning**:
+
+For high-concurrency stress testing, increase the PostgreSQL max connections in `docker/middleware.env`:
+
+```bash
+# Edit docker/middleware.env
+POSTGRES_MAX_CONNECTIONS=200  # Default is 100
+
+# Recommended values for different load levels:
+# Light load (10-50 users): 100 (default)
+# Medium load (50-200 users): 200
+# Heavy load (200-1000 users): 500
+```
+
+After changing, restart the PostgreSQL container:
+
+```bash
+docker compose -f docker/docker-compose.middleware.yaml down db
+docker compose -f docker/docker-compose.middleware.yaml up -d db
+```
+
+**Note**: Each connection uses ~10MB of RAM. Ensure your database server has sufficient memory:
+
+- 100 connections: ~1GB RAM
+- 200 connections: ~2GB RAM
+- 500 connections: ~5GB RAM
+
+### System Optimizations
+
+1. **Increase file descriptor limits** (a verification sketch follows this list):
+
+   ```bash
+   ulimit -n 65536
+   ```
+
+1. **TCP tuning for high concurrency** (Linux):
+
+   ```bash
+   # Increase TCP buffer sizes
+   sudo sysctl -w net.core.rmem_max=134217728
+   sudo sysctl -w net.core.wmem_max=134217728
+
+   # Enable TCP fast open
+   sudo sysctl -w net.ipv4.tcp_fastopen=3
+   ```
+
+1. **macOS specific**:
+
+   ```bash
+   # Increase maximum connections
+   sudo sysctl -w kern.ipc.somaxconn=2048
+   ```
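+To confirm a raised limit actually reached the load generator, a quick check like this (illustrative, Unix-only) prints the descriptor limits the Python process inherited:
+
+```python
+# Print the open-file limits the current process runs with (Unix-only).
+import resource
+
+soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+print(f"open-file limit: soft={soft}, hard={hard}")
+```
+
+Each SSE connection holds at least one descriptor open, so the soft limit caps how many concurrent streams one generator process can sustain.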
+## Troubleshooting
+
+### Common Issues
+
+1. **"ModuleNotFoundError: No module named 'locust'"**:
+
+   ```bash
+   # Dependencies are installed automatically, but if needed:
+   uv --project api add --dev locust sseclient-py
+   ```
+
+1. **"API key configuration not found"**:
+
+   ```bash
+   # Run setup
+   python scripts/stress-test/setup_all.py
+   ```
+
+1. **Services not running**:
+
+   ```bash
+   # Start Dify API with Gunicorn (production mode)
+   cd api
+   uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app
+
+   # Start Mock OpenAI server
+   python scripts/stress-test/setup/mock_openai_server.py
+   ```
+
+1. **High error rate**:
+
+   - Reduce concurrency level
+   - Check system resources (CPU, memory)
+   - Review API server logs for errors
+   - Increase timeout values if needed
+
+1. **Permission denied running script**:
+
+   ```bash
+   chmod +x scripts/stress-test/run_locust_stress_test.sh
+   ```
+
+## Advanced Usage
+
+### Running Multiple Iterations
+
+```bash
+# Run stress test 3 times with 60-second intervals
+for i in {1..3}; do
+  echo "Run $i of 3"
+  ./run_locust_stress_test.sh
+  sleep 60
+done
+```
+
+### Custom Locust Options
+
+Run Locust directly with custom options:
+
+```bash
+# With specific user count and spawn rate
+uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \
+  --host http://localhost:5001 --users 50 --spawn-rate 5
+
+# Generate CSV reports
+uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \
+  --host http://localhost:5001 --csv reports/results
+
+# Run for specific duration
+uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \
+  --host http://localhost:5001 --run-time 5m --headless
+```
+
+### Comparing Results
+
+```bash
+# Compare multiple stress test runs
+ls -la reports/locust_summary_*.txt | tail -5
+```
+
+## Interpreting Performance Issues
+
+### High Response Times
+
+Possible causes:
+
+- Database query performance
+- External API latency
+- Insufficient server resources
+- Network congestion
+
+### Low Throughput (RPS < 10)
+
+Check for:
+
+- CPU bottlenecks
+- Memory constraints
+- Database connection pooling
+- API rate limiting
+
+### High Error Rate
+
+Investigate:
+
+- Server error logs
+- Resource exhaustion
+- Timeout configurations
+- Connection limits
+
+## Why Locust?
+
+Locust was chosen over Drill for this stress test because:
+
+1. **Proper SSE Support**: Correctly handles streaming responses without premature closure
+1. **Custom Metrics**: Can track SSE-specific metrics like TTFE and stream duration
+1. **Web UI**: Real-time monitoring and control via web interface
+1. **Python Integration**: Seamlessly integrates with existing Python setup code
+1. **Extensibility**: Easy to customize for specific testing scenarios
+
+## Contributing
+
+To improve the stress test suite:
+
+1. Edit `locust.conf` for configuration changes
+1. Modify `run_locust_stress_test.sh` for workflow improvements
+1. Update question sets for better coverage
+1. Add new metrics or analysis features
diff --git a/scripts/stress-test/cleanup.py b/scripts/stress-test/cleanup.py
new file mode 100755
index 0000000000..05b97be7ca
--- /dev/null
+++ b/scripts/stress-test/cleanup.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+
+import shutil
+import sys
+from pathlib import Path
+
+from common import Logger
+
+
+def cleanup() -> None:
+    """Clean up all configuration files and reports created during setup and stress testing."""
+
+    log = Logger("Cleanup")
+    log.header("Stress Test Cleanup")
+
+    config_dir = Path(__file__).parent / "setup" / "config"
+    reports_dir = Path(__file__).parent / "reports"
+
+    dirs_to_clean = []
+    if config_dir.exists():
+        dirs_to_clean.append(config_dir)
+    if reports_dir.exists():
+        dirs_to_clean.append(reports_dir)
+
+    if not dirs_to_clean:
+        log.success("No directories to clean. 
Everything is already clean.") + return + + log.info("Cleaning up stress test data...") + log.info("This will remove:") + for dir_path in dirs_to_clean: + log.list_item(str(dir_path)) + + # List files that will be deleted + log.separator() + if config_dir.exists(): + config_files = list(config_dir.glob("*.json")) + if config_files: + log.info("Config files to be removed:") + for file in config_files: + log.list_item(file.name) + + if reports_dir.exists(): + report_files = list(reports_dir.glob("*")) + if report_files: + log.info("Report files to be removed:") + for file in report_files: + log.list_item(file.name) + + # Ask for confirmation if running interactively + if sys.stdin.isatty(): + log.separator() + log.warning("This action cannot be undone!") + confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ") + + if confirmation.lower() not in ["yes", "y"]: + log.error("Cleanup cancelled.") + return + + try: + # Remove directories and all their contents + for dir_path in dirs_to_clean: + shutil.rmtree(dir_path) + log.success(f"{dir_path.name} directory removed successfully!") + + log.separator() + log.info("To run the setup again, execute:") + log.list_item("python setup_all.py") + log.info("Or run scripts individually in this order:") + log.list_item("python setup/mock_openai_server.py (in a separate terminal)") + log.list_item("python setup/setup_admin.py") + log.list_item("python setup/login_admin.py") + log.list_item("python setup/install_openai_plugin.py") + log.list_item("python setup/configure_openai_plugin.py") + log.list_item("python setup/import_workflow_app.py") + log.list_item("python setup/create_api_key.py") + log.list_item("python setup/publish_workflow.py") + log.list_item("python setup/run_workflow.py") + + except PermissionError as e: + log.error(f"Permission denied: {e}") + log.info("Try running with appropriate permissions.") + except Exception as e: + log.error(f"An error occurred during cleanup: {e}") + + +if __name__ == "__main__": + cleanup() diff --git a/scripts/stress-test/common/__init__.py b/scripts/stress-test/common/__init__.py new file mode 100644 index 0000000000..a38d972ffb --- /dev/null +++ b/scripts/stress-test/common/__init__.py @@ -0,0 +1,6 @@ +"""Common utilities for Dify benchmark suite.""" + +from .config_helper import config_helper +from .logger_helper import Logger, ProgressLogger + +__all__ = ["Logger", "ProgressLogger", "config_helper"] diff --git a/scripts/stress-test/common/config_helper.py b/scripts/stress-test/common/config_helper.py new file mode 100644 index 0000000000..75fcbffa6f --- /dev/null +++ b/scripts/stress-test/common/config_helper.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 + +import json +from pathlib import Path +from typing import Any + + +class ConfigHelper: + """Helper class for reading and writing configuration files.""" + + def __init__(self, base_dir: Path | None = None): + """Initialize ConfigHelper with base directory. + + Args: + base_dir: Base directory for config files. If None, uses setup/config + """ + if base_dir is None: + # Default to config directory in setup folder + base_dir = Path(__file__).parent.parent / "setup" / "config" + self.base_dir = base_dir + self.state_file = "stress_test_state.json" + + def ensure_config_dir(self) -> None: + """Ensure the config directory exists.""" + self.base_dir.mkdir(exist_ok=True, parents=True) + + def get_config_path(self, filename: str) -> Path: + """Get the full path for a config file. 
+ + Args: + filename: Name of the config file (e.g., 'admin_config.json') + + Returns: + Full path to the config file + """ + if not filename.endswith(".json"): + filename += ".json" + return self.base_dir / filename + + def read_config(self, filename: str) -> dict[str, Any] | None: + """Read a configuration file. + + DEPRECATED: Use read_state() or get_state_section() for new code. + This method provides backward compatibility. + + Args: + filename: Name of the config file to read + + Returns: + Dictionary containing config data, or None if file doesn't exist + """ + # Provide backward compatibility for old config names + if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: + section_map = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + return self.get_state_section(section_map[filename]) + + config_path = self.get_config_path(filename) + + if not config_path.exists(): + return None + + try: + with open(config_path) as f: + return json.load(f) + except (OSError, json.JSONDecodeError) as e: + print(f"❌ Error reading {filename}: {e}") + return None + + def write_config(self, filename: str, data: dict[str, Any]) -> bool: + """Write data to a configuration file. + + DEPRECATED: Use write_state() or update_state_section() for new code. + This method provides backward compatibility. + + Args: + filename: Name of the config file to write + data: Dictionary containing data to save + + Returns: + True if successful, False otherwise + """ + # Provide backward compatibility for old config names + if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: + section_map = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + return self.update_state_section(section_map[filename], data) + + self.ensure_config_dir() + config_path = self.get_config_path(filename) + + try: + with open(config_path, "w") as f: + json.dump(data, f, indent=2) + return True + except OSError as e: + print(f"❌ Error writing {filename}: {e}") + return False + + def config_exists(self, filename: str) -> bool: + """Check if a config file exists. + + Args: + filename: Name of the config file to check + + Returns: + True if file exists, False otherwise + """ + return self.get_config_path(filename).exists() + + def delete_config(self, filename: str) -> bool: + """Delete a configuration file. + + Args: + filename: Name of the config file to delete + + Returns: + True if successful, False otherwise + """ + config_path = self.get_config_path(filename) + + if not config_path.exists(): + return True # Already doesn't exist + + try: + config_path.unlink() + return True + except OSError as e: + print(f"❌ Error deleting {filename}: {e}") + return False + + def read_state(self) -> dict[str, Any] | None: + """Read the entire stress test state. + + Returns: + Dictionary containing all state data, or None if file doesn't exist + """ + state_path = self.get_config_path(self.state_file) + if not state_path.exists(): + return None + + try: + with open(state_path) as f: + return json.load(f) + except (OSError, json.JSONDecodeError) as e: + print(f"❌ Error reading {self.state_file}: {e}") + return None + + def write_state(self, data: dict[str, Any]) -> bool: + """Write the entire stress test state. 
+ + Args: + data: Dictionary containing all state data to save + + Returns: + True if successful, False otherwise + """ + self.ensure_config_dir() + state_path = self.get_config_path(self.state_file) + + try: + with open(state_path, "w") as f: + json.dump(data, f, indent=2) + return True + except OSError as e: + print(f"❌ Error writing {self.state_file}: {e}") + return False + + def update_state_section(self, section: str, data: dict[str, Any]) -> bool: + """Update a specific section of the stress test state. + + Args: + section: Name of the section to update (e.g., 'admin', 'auth', 'app', 'api_key') + data: Dictionary containing section data to save + + Returns: + True if successful, False otherwise + """ + state = self.read_state() or {} + state[section] = data + return self.write_state(state) + + def get_state_section(self, section: str) -> dict[str, Any] | None: + """Get a specific section from the stress test state. + + Args: + section: Name of the section to get (e.g., 'admin', 'auth', 'app', 'api_key') + + Returns: + Dictionary containing section data, or None if not found + """ + state = self.read_state() + if state: + return state.get(section) + return None + + def get_token(self) -> str | None: + """Get the access token from auth section. + + Returns: + Access token string or None if not found + """ + auth = self.get_state_section("auth") + if auth: + return auth.get("access_token") + return None + + def get_app_id(self) -> str | None: + """Get the app ID from app section. + + Returns: + App ID string or None if not found + """ + app = self.get_state_section("app") + if app: + return app.get("app_id") + return None + + def get_api_key(self) -> str | None: + """Get the API key token from api_key section. + + Returns: + API key token string or None if not found + """ + api_key = self.get_state_section("api_key") + if api_key: + return api_key.get("token") + return None + + +# Create a default instance for convenience +config_helper = ConfigHelper() diff --git a/scripts/stress-test/common/logger_helper.py b/scripts/stress-test/common/logger_helper.py new file mode 100644 index 0000000000..c522685f1d --- /dev/null +++ b/scripts/stress-test/common/logger_helper.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 + +import sys +import time +from enum import Enum + + +class LogLevel(Enum): + """Log levels with associated colors and symbols.""" + + DEBUG = ("🔍", "\033[90m") # Gray + INFO = ("ℹ️ ", "\033[94m") # Blue + SUCCESS = ("✅", "\033[92m") # Green + WARNING = ("⚠️ ", "\033[93m") # Yellow + ERROR = ("❌", "\033[91m") # Red + STEP = ("🚀", "\033[96m") # Cyan + PROGRESS = ("📋", "\033[95m") # Magenta + + +class Logger: + """Logger class for formatted console output.""" + + def __init__(self, name: str | None = None, use_colors: bool = True): + """Initialize logger. + + Args: + name: Optional name for the logger (e.g., script name) + use_colors: Whether to use ANSI color codes + """ + self.name = name + self.use_colors = use_colors and sys.stdout.isatty() + self._reset_color = "\033[0m" if self.use_colors else "" + + def _format_message(self, level: LogLevel, message: str, indent: int = 0) -> str: + """Format a log message with level, color, and indentation. 
+ + Args: + level: Log level + message: Message to log + indent: Number of spaces to indent + + Returns: + Formatted message string + """ + symbol, color = level.value + color = color if self.use_colors else "" + reset = self._reset_color + + prefix = " " * indent + + if self.name and level in [LogLevel.STEP, LogLevel.ERROR]: + return f"{prefix}{color}{symbol} [{self.name}] {message}{reset}" + else: + return f"{prefix}{color}{symbol} {message}{reset}" + + def debug(self, message: str, indent: int = 0) -> None: + """Log debug message.""" + print(self._format_message(LogLevel.DEBUG, message, indent)) + + def info(self, message: str, indent: int = 0) -> None: + """Log info message.""" + print(self._format_message(LogLevel.INFO, message, indent)) + + def success(self, message: str, indent: int = 0) -> None: + """Log success message.""" + print(self._format_message(LogLevel.SUCCESS, message, indent)) + + def warning(self, message: str, indent: int = 0) -> None: + """Log warning message.""" + print(self._format_message(LogLevel.WARNING, message, indent)) + + def error(self, message: str, indent: int = 0) -> None: + """Log error message.""" + print(self._format_message(LogLevel.ERROR, message, indent), file=sys.stderr) + + def step(self, message: str, indent: int = 0) -> None: + """Log a step in a process.""" + print(self._format_message(LogLevel.STEP, message, indent)) + + def progress(self, message: str, indent: int = 0) -> None: + """Log progress information.""" + print(self._format_message(LogLevel.PROGRESS, message, indent)) + + def separator(self, char: str = "-", length: int = 60) -> None: + """Print a separator line.""" + print(char * length) + + def header(self, title: str, width: int = 60) -> None: + """Print a formatted header.""" + if self.use_colors: + print(f"\n\033[1m{'=' * width}\033[0m") # Bold + print(f"\033[1m{title.center(width)}\033[0m") + print(f"\033[1m{'=' * width}\033[0m\n") + else: + print(f"\n{'=' * width}") + print(title.center(width)) + print(f"{'=' * width}\n") + + def box(self, title: str, width: int = 60) -> None: + """Print a title in a box.""" + border = "═" * (width - 2) + if self.use_colors: + print(f"\033[1m╔{border}╗\033[0m") + print(f"\033[1m║{title.center(width - 2)}║\033[0m") + print(f"\033[1m╚{border}╝\033[0m") + else: + print(f"╔{border}╗") + print(f"║{title.center(width - 2)}║") + print(f"╚{border}╝") + + def list_item(self, item: str, indent: int = 2) -> None: + """Print a list item.""" + prefix = " " * indent + print(f"{prefix}• {item}") + + def key_value(self, key: str, value: str, indent: int = 2) -> None: + """Print a key-value pair.""" + prefix = " " * indent + if self.use_colors: + print(f"{prefix}\033[1m{key}:\033[0m {value}") + else: + print(f"{prefix}{key}: {value}") + + def spinner_start(self, message: str) -> None: + """Start a spinner (simple implementation).""" + sys.stdout.write(f"\r{message}... ") + sys.stdout.flush() + + def spinner_stop(self, success: bool = True, message: str | None = None) -> None: + """Stop the spinner and show result.""" + if success: + symbol = "✅" if message else "Done" + sys.stdout.write(f"\r{symbol} {message or ''}\n") + else: + symbol = "❌" if message else "Failed" + sys.stdout.write(f"\r{symbol} {message or ''}\n") + sys.stdout.flush() + + +class ProgressLogger: + """Logger for tracking progress through multiple steps.""" + + def __init__(self, total_steps: int, logger: Logger | None = None): + """Initialize progress logger. 
+ + Args: + total_steps: Total number of steps + logger: Logger instance to use (creates new if None) + """ + self.total_steps = total_steps + self.current_step = 0 + self.logger = logger or Logger() + self.start_time = time.time() + + def next_step(self, description: str) -> None: + """Move to next step and log it.""" + self.current_step += 1 + elapsed = time.time() - self.start_time + + if self.logger.use_colors: + progress_bar = self._create_progress_bar() + print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}") + self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") + else: + print(f"\n[Step {self.current_step}/{self.total_steps}]") + self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") + + def _create_progress_bar(self, width: int = 20) -> str: + """Create a simple progress bar.""" + filled = int(width * self.current_step / self.total_steps) + bar = "█" * filled + "░" * (width - filled) + percentage = int(100 * self.current_step / self.total_steps) + return f"[{bar}] {percentage}%" + + def complete(self) -> None: + """Mark progress as complete.""" + elapsed = time.time() - self.start_time + self.logger.success(f"All steps completed! Total time: {elapsed:.1f}s") + + +# Create default logger instance +logger = Logger() + + +# Convenience functions using default logger +def debug(message: str, indent: int = 0) -> None: + """Log debug message using default logger.""" + logger.debug(message, indent) + + +def info(message: str, indent: int = 0) -> None: + """Log info message using default logger.""" + logger.info(message, indent) + + +def success(message: str, indent: int = 0) -> None: + """Log success message using default logger.""" + logger.success(message, indent) + + +def warning(message: str, indent: int = 0) -> None: + """Log warning message using default logger.""" + logger.warning(message, indent) + + +def error(message: str, indent: int = 0) -> None: + """Log error message using default logger.""" + logger.error(message, indent) + + +def step(message: str, indent: int = 0) -> None: + """Log step using default logger.""" + logger.step(message, indent) + + +def progress(message: str, indent: int = 0) -> None: + """Log progress using default logger.""" + logger.progress(message, indent) diff --git a/scripts/stress-test/locust.conf b/scripts/stress-test/locust.conf new file mode 100644 index 0000000000..87bd8c2870 --- /dev/null +++ b/scripts/stress-test/locust.conf @@ -0,0 +1,37 @@ +# Locust configuration file for Dify SSE benchmark + +# Target host +host = http://localhost:5001 + +# Number of users to simulate +users = 10 + +# Spawn rate (users per second) +spawn-rate = 2 + +# Run time (use format like 30s, 5m, 1h) +run-time = 1m + +# Locustfile to use +locustfile = scripts/stress-test/sse_benchmark.py + +# Headless mode (no web UI) +headless = true + +# Print stats in the console +print-stats = true + +# Only print summary stats +only-summary = false + +# Reset statistics after ramp-up +reset-stats = false + +# Log level +loglevel = INFO + +# CSV output (uncomment to enable) +# csv = reports/locust_results + +# HTML report (uncomment to enable) +# html = reports/locust_report.html \ No newline at end of file diff --git a/scripts/stress-test/run_locust_stress_test.sh b/scripts/stress-test/run_locust_stress_test.sh new file mode 100755 index 0000000000..665cb68754 --- /dev/null +++ b/scripts/stress-test/run_locust_stress_test.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# Run Dify SSE Stress Test using Locust + +set -e + +# Get the directory 
where this script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Go to project root first, then to script dir +PROJECT_ROOT="$( cd "${SCRIPT_DIR}/../.." && pwd )" +cd "${PROJECT_ROOT}" +STRESS_TEST_DIR="scripts/stress-test" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +# Configuration +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +REPORT_DIR="${STRESS_TEST_DIR}/reports" +CSV_PREFIX="${REPORT_DIR}/locust_${TIMESTAMP}" +HTML_REPORT="${REPORT_DIR}/locust_report_${TIMESTAMP}.html" +SUMMARY_REPORT="${REPORT_DIR}/locust_summary_${TIMESTAMP}.txt" + +# Create reports directory if it doesn't exist +mkdir -p "${REPORT_DIR}" + +echo -e "${BLUE}╔════════════════════════════════════════════════════════════════╗${NC}" +echo -e "${BLUE}║ DIFY SSE WORKFLOW STRESS TEST (LOCUST) ║${NC}" +echo -e "${BLUE}╚════════════════════════════════════════════════════════════════╝${NC}" +echo + +# Check if services are running +echo -e "${YELLOW}Checking services...${NC}" + +# Check Dify API +if curl -s -f http://localhost:5001/health > /dev/null 2>&1; then + echo -e "${GREEN}✓ Dify API is running${NC}" + + # Warn if running in debug mode (check for werkzeug in process) + if ps aux | grep -v grep | grep -q "werkzeug.*5001\|flask.*run.*5001"; then + echo -e "${YELLOW}⚠ WARNING: API appears to be running in debug mode (Flask development server)${NC}" + echo -e "${YELLOW} This will give inaccurate benchmark results!${NC}" + echo -e "${YELLOW} For accurate benchmarking, restart with Gunicorn:${NC}" + echo -e "${CYAN} cd api && uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app${NC}" + echo + echo -n "Continue anyway? (not recommended) [y/N]: " + read -t 10 continue_debug || continue_debug="n" + if [ "$continue_debug" != "y" ] && [ "$continue_debug" != "Y" ]; then + echo -e "${RED}Benchmark cancelled. Please restart API with Gunicorn.${NC}" + exit 1 + fi + fi +else + echo -e "${RED}✗ Dify API is not running on port 5001${NC}" + echo -e "${YELLOW} Start it with Gunicorn for accurate benchmarking:${NC}" + echo -e "${CYAN} cd api && uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app${NC}" + exit 1 +fi + +# Check Mock OpenAI server +if curl -s -f http://localhost:5004/v1/models > /dev/null 2>&1; then + echo -e "${GREEN}✓ Mock OpenAI server is running${NC}" +else + echo -e "${RED}✗ Mock OpenAI server is not running on port 5004${NC}" + echo -e "${YELLOW} Start it with: python scripts/stress-test/setup/mock_openai_server.py${NC}" + exit 1 +fi + +# Check API token exists +if [ ! 
-f "${STRESS_TEST_DIR}/setup/config/stress_test_state.json" ]; then + echo -e "${RED}✗ Stress test configuration not found${NC}" + echo -e "${YELLOW} Run setup first: python scripts/stress-test/setup_all.py${NC}" + exit 1 +fi + +API_TOKEN=$(python3 -c "import json; state = json.load(open('${STRESS_TEST_DIR}/setup/config/stress_test_state.json')); print(state.get('api_key', {}).get('token', ''))" 2>/dev/null) +if [ -z "$API_TOKEN" ]; then + echo -e "${RED}✗ Failed to read API token from stress test state${NC}" + exit 1 +fi +echo -e "${GREEN}✓ API token found: ${API_TOKEN:0:10}...${NC}" + +echo +echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" +echo -e "${CYAN} STRESS TEST PARAMETERS ${NC}" +echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + +# Parse configuration +USERS=$(grep "^users" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') +SPAWN_RATE=$(grep "^spawn-rate" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') +RUN_TIME=$(grep "^run-time" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') + +echo -e " ${YELLOW}Users:${NC} $USERS concurrent users" +echo -e " ${YELLOW}Spawn Rate:${NC} $SPAWN_RATE users/second" +echo -e " ${YELLOW}Duration:${NC} $RUN_TIME" +echo -e " ${YELLOW}Mode:${NC} SSE Streaming" +echo + +# Ask user for run mode +echo -e "${YELLOW}Select run mode:${NC}" +echo " 1) Headless (CLI only) - Default" +echo " 2) Web UI (http://localhost:8089)" +echo -n "Choice [1]: " +read -t 10 choice || choice="1" +echo + +# Use SSE stress test script +LOCUST_SCRIPT="${STRESS_TEST_DIR}/sse_benchmark.py" + +# Prepare Locust command +if [ "$choice" = "2" ]; then + echo -e "${BLUE}Starting Locust with Web UI...${NC}" + echo -e "${YELLOW}Access the web interface at: ${CYAN}http://localhost:8089${NC}" + echo + + # Run with web UI + uv --project api run locust \ + -f ${LOCUST_SCRIPT} \ + --host http://localhost:5001 \ + --web-port 8089 +else + echo -e "${BLUE}Starting stress test in headless mode...${NC}" + echo + + # Run in headless mode with CSV output + uv --project api run locust \ + -f ${LOCUST_SCRIPT} \ + --host http://localhost:5001 \ + --users $USERS \ + --spawn-rate $SPAWN_RATE \ + --run-time $RUN_TIME \ + --headless \ + --print-stats \ + --csv=$CSV_PREFIX \ + --html=$HTML_REPORT \ + 2>&1 | tee $SUMMARY_REPORT + + echo + echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${GREEN} STRESS TEST COMPLETE ${NC}" + echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" + echo + echo -e "${BLUE}Reports generated:${NC}" + echo -e " ${YELLOW}Summary:${NC} $SUMMARY_REPORT" + echo -e " ${YELLOW}HTML Report:${NC} $HTML_REPORT" + echo -e " ${YELLOW}CSV Stats:${NC} ${CSV_PREFIX}_stats.csv" + echo -e " ${YELLOW}CSV History:${NC} ${CSV_PREFIX}_stats_history.csv" + echo + echo -e "${CYAN}View HTML report:${NC}" + echo " open $HTML_REPORT # macOS" + echo " xdg-open $HTML_REPORT # Linux" + echo + + # Parse and display key metrics + echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${CYAN} KEY METRICS ${NC}" + echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + + if [ -f "${CSV_PREFIX}_stats.csv" ]; then + python3 - < None: + """Configure OpenAI plugin with mock server credentials.""" + + log = Logger("ConfigPlugin") + log.header("Configuring OpenAI Plugin") + + # Read token from config + access_token = config_helper.get_token() + if not 
access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + log.step("Configuring OpenAI plugin with mock server...") + + # API endpoint for plugin configuration + base_url = "http://localhost:5001" + config_endpoint = f"{base_url}/console/api/workspaces/current/model-providers/langgenius/openai/openai/credentials" + + # Configuration payload with mock server + config_payload = { + "credentials": { + "openai_api_key": "apikey", + "openai_organization": None, + "openai_api_base": "http://host.docker.internal:5004", + } + } + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the configuration request + with httpx.Client() as client: + response = client.post( + config_endpoint, + json=config_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + log.success("OpenAI plugin configured successfully!") + log.key_value("API Base", config_payload["credentials"]["openai_api_base"]) + log.key_value("API Key", config_payload["credentials"]["openai_api_key"]) + + elif response.status_code == 201: + log.success("OpenAI plugin credentials created successfully!") + log.key_value("API Base", config_payload["credentials"]["openai_api_base"]) + log.key_value("API Key", config_payload["credentials"]["openai_api_key"]) + + elif response.status_code == 401: + log.error("Configuration failed: Unauthorized") + log.info("Token may have expired. 
Please run login_admin.py again")
+            else:
+                log.error(f"Configuration failed with status code: {response.status_code}")
+                log.debug(f"Response: {response.text}")
+
+    except httpx.ConnectError:
+        log.error("Could not connect to Dify API at http://localhost:5001")
+        log.info("Make sure the API server is running with: ./dev/start-api")
+    except Exception as e:
+        log.error(f"An error occurred: {e}")
+
+
+if __name__ == "__main__":
+    configure_openai_plugin()
diff --git a/scripts/stress-test/setup/create_api_key.py b/scripts/stress-test/setup/create_api_key.py
new file mode 100755
index 0000000000..cd04fe57eb
--- /dev/null
+++ b/scripts/stress-test/setup/create_api_key.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+import json
+
+import httpx
+from common import Logger, config_helper
+
+
+def create_api_key() -> None:
+    """Create API key for the imported app."""
+
+    log = Logger("CreateAPIKey")
+    log.header("Creating API Key")
+
+    # Read token from config
+    access_token = config_helper.get_token()
+    if not access_token:
+        log.error("No access token found in config")
+        return
+
+    # Read app_id from config
+    app_id = config_helper.get_app_id()
+    if not app_id:
+        log.error("No app_id found in config")
+        log.info("Please run import_workflow_app.py first to import the app")
+        return
+
+    log.step(f"Creating API key for app: {app_id}")
+
+    # API endpoint for creating API key
+    base_url = "http://localhost:5001"
+    api_key_endpoint = f"{base_url}/console/api/apps/{app_id}/api-keys"
+
+    headers = {
+        "Accept": "*/*",
+        "Accept-Language": "en-US,en;q=0.9",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "Content-Length": "0",
+        "DNT": "1",
+        "Origin": "http://localhost:3000",
+        "Pragma": "no-cache",
+        "Referer": "http://localhost:3000/",
+        "Sec-Fetch-Dest": "empty",
+        "Sec-Fetch-Mode": "cors",
+        "Sec-Fetch-Site": "same-site",
+        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36",
+        "authorization": f"Bearer {access_token}",
+        "content-type": "application/json",
+        "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"',
+        "sec-ch-ua-mobile": "?0",
+        "sec-ch-ua-platform": '"macOS"',
+    }
+
+    cookies = {"locale": "en-US"}
+
+    try:
+        # Make the API key creation request
+        with httpx.Client() as client:
+            response = client.post(
+                api_key_endpoint,
+                headers=headers,
+                cookies=cookies,
+            )
+
+            if response.status_code in (200, 201):
+                response_data = response.json()
+
+                api_key_id = response_data.get("id")
+                api_key_token = response_data.get("token")
+
+                if api_key_token:
+                    log.success("API key created successfully!")
+                    log.key_value("Key ID", api_key_id)
+                    log.key_value("Token", api_key_token)
+                    log.key_value("Type", response_data.get("type"))
+
+                    # Save API key to config
+                    api_key_config = {
+                        "id": api_key_id,
+                        "token": api_key_token,
+                        "type": response_data.get("type"),
+                        "app_id": app_id,
+                        "created_at": response_data.get("created_at"),
+                    }
+
+                    if config_helper.write_config("api_key_config", api_key_config):
+                        log.info(f"API key saved to: {config_helper.get_config_path('stress_test_state')}")
+                else:
+                    log.error("No API token received")
+                    log.debug(f"Response: {json.dumps(response_data, indent=2)}")
+
+            elif response.status_code == 401:
+                log.error("API key creation failed: Unauthorized")
+                log.info("Token may have expired. 
Please run login_admin.py again") + else: + log.error(f"API key creation failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + create_api_key() diff --git a/scripts/stress-test/setup/dsl/workflow_llm.yml b/scripts/stress-test/setup/dsl/workflow_llm.yml new file mode 100644 index 0000000000..c0fd2c7d8b --- /dev/null +++ b/scripts/stress-test/setup/dsl/workflow_llm.yml @@ -0,0 +1,176 @@ +app: + description: '' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: workflow_llm + use_icon_as_answer_icon: false +dependencies: +- current_identifier: null + type: marketplace + value: + marketplace_plugin_unique_identifier: langgenius/openai:0.2.5@373362a028986aae53a7baf73a7f11991ba3c22c69eaf97d6cde048cfd4a9f98 +kind: app +version: 0.4.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: llm + id: 1757611990947-source-1757611992921-target + source: '1757611990947' + sourceHandle: source + target: '1757611992921' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: llm + targetType: end + id: 1757611992921-source-1757611996447-target + source: '1757611992921' + sourceHandle: source + target: '1757611996447' + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + desc: '' + selected: false + title: Start + type: start + variables: + - label: question + max_length: null + options: [] + required: true + type: text-input + variable: question + height: 90 + id: '1757611990947' + position: + x: 30 + y: 245 + positionAbsolute: + x: 30 + y: 245 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + context: + enabled: false + variable_selector: [] + desc: '' + model: + completion_params: + temperature: 0.7 + mode: chat + name: gpt-4o + provider: langgenius/openai/openai + prompt_template: + - id: c165fcb6-f1f0-42f2-abab-e81982434deb + role: system + text: '' + - role: user + text: '{{#1757611990947.question#}}' + selected: false + title: LLM + type: llm + variables: [] + vision: + enabled: false + height: 90 + id: '1757611992921' + position: + x: 334 + y: 245 + positionAbsolute: + x: 334 + y: 245 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + desc: '' + outputs: + - value_selector: + - '1757611992921' + - text 
+          value_type: string
+          variable: answer
+        selected: false
+        title: End
+        type: end
+      height: 90
+      id: '1757611996447'
+      position:
+        x: 638
+        y: 245
+      positionAbsolute:
+        x: 638
+        y: 245
+      selected: true
+      sourcePosition: right
+      targetPosition: left
+      type: custom
+      width: 244
+    viewport:
+      x: 0
+      y: 0
+      zoom: 0.7
diff --git a/scripts/stress-test/setup/import_workflow_app.py b/scripts/stress-test/setup/import_workflow_app.py
new file mode 100755
index 0000000000..86d0239e35
--- /dev/null
+++ b/scripts/stress-test/setup/import_workflow_app.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+import json
+
+import httpx
+from common import Logger, config_helper
+
+
+def import_workflow_app() -> None:
+    """Import workflow app from DSL file and save app_id."""
+
+    log = Logger("ImportApp")
+    log.header("Importing Workflow Application")
+
+    # Read token from config
+    access_token = config_helper.get_token()
+    if not access_token:
+        log.error("No access token found in config")
+        log.info("Please run login_admin.py first to get access token")
+        return
+
+    # Read workflow DSL file
+    dsl_path = Path(__file__).parent / "dsl" / "workflow_llm.yml"
+
+    if not dsl_path.exists():
+        log.error(f"DSL file not found: {dsl_path}")
+        return
+
+    with open(dsl_path) as f:
+        yaml_content = f.read()
+
+    log.step("Importing workflow app from DSL...")
+    log.key_value("DSL file", dsl_path.name)
+
+    # API endpoint for app import
+    base_url = "http://localhost:5001"
+    import_endpoint = f"{base_url}/console/api/apps/imports"
+
+    # Import payload
+    import_payload = {"mode": "yaml-content", "yaml_content": yaml_content}
+
+    headers = {
+        "Accept": "*/*",
+        "Accept-Language": "en-US,en;q=0.9",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "DNT": "1",
+        "Origin": "http://localhost:3000",
+        "Pragma": "no-cache",
+        "Referer": "http://localhost:3000/",
+        "Sec-Fetch-Dest": "empty",
+        "Sec-Fetch-Mode": "cors",
+        "Sec-Fetch-Site": "same-site",
+        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36",
+        "authorization": f"Bearer {access_token}",
+        "content-type": "application/json",
+        "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"',
+        "sec-ch-ua-mobile": "?0",
+        "sec-ch-ua-platform": '"macOS"',
+    }
+
+    cookies = {"locale": "en-US"}
+
+    try:
+        # Make the import request
+        with httpx.Client() as client:
+            response = client.post(
+                import_endpoint,
+                json=import_payload,
+                headers=headers,
+                cookies=cookies,
+            )
+
+            if response.status_code == 200:
+                response_data = response.json()
+
+                # Check import status
+                if response_data.get("status") == "completed":
+                    app_id = response_data.get("app_id")
+
+                    if app_id:
+                        log.success("Workflow app imported successfully!")
+                        log.key_value("App ID", app_id)
+                        log.key_value("App Mode", response_data.get("app_mode"))
+                        log.key_value("DSL Version", response_data.get("imported_dsl_version"))
+
+                        # Save app_id to config
+                        app_config = {
+                            "app_id": app_id,
+                            "app_mode": response_data.get("app_mode"),
+                            "app_name": "workflow_llm",
+                            "dsl_version": response_data.get("imported_dsl_version"),
+                        }
+
+                        if config_helper.write_config("app_config", app_config):
+                            log.info(f"App config saved to: {config_helper.get_config_path('stress_test_state')}")
+                    else:
+                        log.error("Import completed but no app_id received")
+                        log.debug(f"Response: {json.dumps(response_data, indent=2)}")
+
+                elif 
response_data.get("status") == "failed": + log.error("Import failed") + log.error(f"Error: {response_data.get('error')}") + else: + log.warning(f"Import status: {response_data.get('status')}") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response.status_code == 401: + log.error("Import failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error(f"Import failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + import_workflow_app() diff --git a/scripts/stress-test/setup/install_openai_plugin.py b/scripts/stress-test/setup/install_openai_plugin.py new file mode 100755 index 0000000000..055e5661f8 --- /dev/null +++ b/scripts/stress-test/setup/install_openai_plugin.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import time + +import httpx +from common import Logger, config_helper + + +def install_openai_plugin() -> None: + """Install OpenAI plugin using saved access token.""" + + log = Logger("InstallPlugin") + log.header("Installing OpenAI Plugin") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + log.step("Installing OpenAI plugin...") + + # API endpoint for plugin installation + base_url = "http://localhost:5001" + install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace" + + # Plugin identifier + plugin_payload = { + "plugin_unique_identifiers": [ + "langgenius/openai:0.2.5@373362a028986aae53a7baf73a7f11991ba3c22c69eaf97d6cde048cfd4a9f98" + ] + } + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the installation request + with httpx.Client() as client: + response = client.post( + install_endpoint, + json=plugin_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + response_data = response.json() + task_id = response_data.get("task_id") + + if not task_id: + log.error("No task ID received from installation request") + return + + log.progress(f"Installation task created: {task_id}") + log.info("Polling for task completion...") + + # Poll for task completion + task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}" + + max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max + attempt = 0 + + log.spinner_start("Installing plugin") + + while attempt 
< max_attempts: + attempt += 1 + time.sleep(2) # Wait 2 seconds between polls + + task_response = client.get( + task_endpoint, + headers=headers, + cookies=cookies, + ) + + if task_response.status_code != 200: + log.spinner_stop( + success=False, + message=f"Failed to get task status: {task_response.status_code}", + ) + return + + task_data = task_response.json() + task_info = task_data.get("task", {}) + status = task_info.get("status") + + if status == "success": + log.spinner_stop(success=True, message="Plugin installed!") + log.success("OpenAI plugin installed successfully!") + + # Display plugin info + plugins = task_info.get("plugins", []) + if plugins: + plugin_info = plugins[0] + log.key_value("Plugin ID", plugin_info.get("plugin_id")) + log.key_value("Message", plugin_info.get("message")) + break + + elif status == "failed": + log.spinner_stop(success=False, message="Installation failed") + log.error("Plugin installation failed") + plugins = task_info.get("plugins", []) + if plugins: + for plugin in plugins: + log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}") + break + + # Continue polling if status is "pending" or other + + else: + log.spinner_stop(success=False, message="Installation timed out") + log.error("Installation timed out after 60 seconds") + + elif response.status_code == 401: + log.error("Installation failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + elif response.status_code == 409: + log.warning("Plugin may already be installed") + log.debug(f"Response: {response.text}") + else: + log.error(f"Installation failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + install_openai_plugin() diff --git a/scripts/stress-test/setup/login_admin.py b/scripts/stress-test/setup/login_admin.py new file mode 100755 index 0000000000..572b8fb650 --- /dev/null +++ b/scripts/stress-test/setup/login_admin.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def login_admin() -> None: + """Login with admin account and save access token.""" + + log = Logger("Login") + log.header("Admin Login") + + # Read admin credentials from config + admin_config = config_helper.read_config("admin_config") + + if not admin_config: + log.error("Admin config not found") + log.info("Please run setup_admin.py first to create the admin account") + return + + log.info(f"Logging in with email: {admin_config['email']}") + + # API login endpoint + base_url = "http://localhost:5001" + login_endpoint = f"{base_url}/console/api/login" + + # Prepare login payload + login_payload = { + "email": admin_config["email"], + "password": admin_config["password"], + "remember_me": True, + } + + try: + # Make the login request + with httpx.Client() as client: + response = client.post( + login_endpoint, + json=login_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 200: + log.success("Login successful!") + + # Extract token from response + response_data = response.json() + + # Check if login was successful + if response_data.get("result") != "success": + 
log.error(f"Login failed: {response_data}") + return + + # Extract tokens from data field + token_data = response_data.get("data", {}) + access_token = token_data.get("access_token", "") + refresh_token = token_data.get("refresh_token", "") + + if not access_token: + log.error("No access token found in response") + log.debug(f"Full response: {json.dumps(response_data, indent=2)}") + return + + # Save token to config file + token_config = { + "email": admin_config["email"], + "access_token": access_token, + "refresh_token": refresh_token, + } + + # Save token config + if config_helper.write_config("token_config", token_config): + log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}") + + # Show truncated token for verification + token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved" + log.key_value("Access token", token_display) + + elif response.status_code == 401: + log.error("Login failed: Invalid credentials") + log.debug(f"Response: {response.text}") + else: + log.error(f"Login failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + login_admin() diff --git a/scripts/stress-test/setup/mock_openai_server.py b/scripts/stress-test/setup/mock_openai_server.py new file mode 100755 index 0000000000..7333c66e57 --- /dev/null +++ b/scripts/stress-test/setup/mock_openai_server.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 + +import json +import time +import uuid +from collections.abc import Iterator +from typing import Any + +from flask import Flask, Response, jsonify, request + +app = Flask(__name__) + +# Mock models list +MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677649963, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal", + }, +] + + +@app.route("/v1/models", methods=["GET"]) +def list_models() -> Any: + """List available models.""" + return jsonify({"object": "list", "data": MODELS}) + + +@app.route("/v1/chat/completions", methods=["POST"]) +def chat_completions() -> Any: + """Handle chat completions.""" + data = request.json or {} + model = data.get("model", "gpt-3.5-turbo") + messages = data.get("messages", []) + stream = data.get("stream", False) + + # Generate mock response + response_content = "This is a mock response from the OpenAI server." + if messages: + last_message = messages[-1].get("content", "") + response_content = f"Mock response to: {last_message[:100]}..." 
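+    # Below: mimic the real API's two response modes, SSE chunks when
+    # stream=True, otherwise a single chat.completion JSON payload.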
+ + if stream: + # Streaming response + def generate() -> Iterator[str]: + # Send initial chunk + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + + # Send content in chunks + words = response_content.split() + for word in words: + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": word + " "}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + time.sleep(0.05) # Simulate streaming delay + + # Send final chunk + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + return Response(generate(), mimetype="text/event-stream") + else: + # Non-streaming response + return jsonify( + { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": response_content}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": len(str(messages)), + "completion_tokens": len(response_content.split()), + "total_tokens": len(str(messages)) + len(response_content.split()), + }, + } + ) + + +@app.route("/v1/completions", methods=["POST"]) +def completions() -> Any: + """Handle text completions.""" + data = request.json or {} + model = data.get("model", "gpt-3.5-turbo-instruct") + prompt = data.get("prompt", "") + + response_text = f"Mock completion for prompt: {prompt[:100]}..." 
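+    # Usage numbers here are rough mocks (whitespace word counts), not
+    # real tokenizer output.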
+
+    return jsonify(
+        {
+            "id": f"cmpl-{uuid.uuid4().hex[:8]}",
+            "object": "text_completion",
+            "created": int(time.time()),
+            "model": model,
+            "choices": [
+                {
+                    "text": response_text,
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": len(prompt.split()),
+                "completion_tokens": len(response_text.split()),
+                "total_tokens": len(prompt.split()) + len(response_text.split()),
+            },
+        }
+    )
+
+
+@app.route("/v1/embeddings", methods=["POST"])
+def embeddings() -> Any:
+    """Handle embeddings requests."""
+    data = request.json or {}
+    model = data.get("model", "text-embedding-ada-002")
+    input_text = data.get("input", "")
+
+    # Generate mock embedding (1536 dimensions for ada-002)
+    mock_embedding = [0.1] * 1536
+
+    return jsonify(
+        {
+            "object": "list",
+            "data": [{"object": "embedding", "embedding": mock_embedding, "index": 0}],
+            "model": model,
+            "usage": {
+                "prompt_tokens": len(input_text.split()),
+                "total_tokens": len(input_text.split()),
+            },
+        }
+    )
+
+
+@app.route("/v1/models/<model_id>", methods=["GET"])
+def get_model(model_id: str) -> tuple[Any, int] | Any:
+    """Get specific model details."""
+    for model in MODELS:
+        if model["id"] == model_id:
+            return jsonify(model)
+
+    return jsonify({"error": "Model not found"}), 404
+
+
+@app.route("/health", methods=["GET"])
+def health() -> Any:
+    """Health check endpoint."""
+    return jsonify({"status": "healthy"})
+
+
+if __name__ == "__main__":
+    print("🚀 Starting Mock OpenAI Server on http://localhost:5004")
+    print("Available endpoints:")
+    print(" - GET /v1/models")
+    print(" - POST /v1/chat/completions")
+    print(" - POST /v1/completions")
+    print(" - POST /v1/embeddings")
+    print(" - GET /v1/models/<model_id>")
+    print(" - GET /health")
+    app.run(host="0.0.0.0", port=5004, debug=True)
diff --git a/scripts/stress-test/setup/publish_workflow.py b/scripts/stress-test/setup/publish_workflow.py
new file mode 100755
index 0000000000..b772eccebd
--- /dev/null
+++ b/scripts/stress-test/setup/publish_workflow.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+import json
+
+import httpx
+from common import Logger, config_helper
+
+
+def publish_workflow() -> None:
+    """Publish the imported workflow app."""
+
+    log = Logger("PublishWorkflow")
+    log.header("Publishing Workflow")
+
+    # Read token from config
+    access_token = config_helper.get_token()
+    if not access_token:
+        log.error("No access token found in config")
+        return
+
+    # Read app_id from config
+    app_id = config_helper.get_app_id()
+    if not app_id:
+        log.error("No app_id found in config")
+        return
+
+    log.step(f"Publishing workflow for app: {app_id}")
+
+    # API endpoint for publishing workflow
+    base_url = "http://localhost:5001"
+    publish_endpoint = f"{base_url}/console/api/apps/{app_id}/workflows/publish"
+
+    # Publish payload
+    publish_payload = {"marked_name": "", "marked_comment": ""}
+
+    headers = {
+        "Accept": "*/*",
+        "Accept-Language": "en-US,en;q=0.9",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "DNT": "1",
+        "Origin": "http://localhost:3000",
+        "Pragma": "no-cache",
+        "Referer": "http://localhost:3000/",
+        "Sec-Fetch-Dest": "empty",
+        "Sec-Fetch-Mode": "cors",
+        "Sec-Fetch-Site": "same-site",
+        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36",
+        "authorization": f"Bearer {access_token}",
+        "content-type": "application/json",
+        "sec-ch-ua": 
'"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the publish request + with httpx.Client() as client: + response = client.post( + publish_endpoint, + json=publish_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200 or response.status_code == 201: + log.success("Workflow published successfully!") + log.key_value("App ID", app_id) + + # Try to parse response if it has JSON content + if response.text: + try: + response_data = response.json() + if response_data: + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + except json.JSONDecodeError: + # Response might be empty or non-JSON + pass + + elif response.status_code == 401: + log.error("Workflow publish failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + elif response.status_code == 404: + log.error("Workflow publish failed: App not found") + log.info("Make sure the app was imported successfully") + else: + log.error(f"Workflow publish failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + publish_workflow() diff --git a/scripts/stress-test/setup/run_workflow.py b/scripts/stress-test/setup/run_workflow.py new file mode 100755 index 0000000000..6da0ff17be --- /dev/null +++ b/scripts/stress-test/setup/run_workflow.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def run_workflow(question: str = "fake question", streaming: bool = True) -> None: + """Run the workflow app with a question.""" + + log = Logger("RunWorkflow") + log.header("Running Workflow") + + # Read API key from config + api_token = config_helper.get_api_key() + if not api_token: + log.error("No API token found in config") + log.info("Please run create_api_key.py first to create an API key") + return + + log.key_value("Question", question) + log.key_value("Mode", "Streaming" if streaming else "Blocking") + log.separator() + + # API endpoint for running workflow + base_url = "http://localhost:5001" + run_endpoint = f"{base_url}/v1/workflows/run" + + # Run payload + run_payload = { + "inputs": {"question": question}, + "user": "default user", + "response_mode": "streaming" if streaming else "blocking", + } + + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + } + + try: + # Make the run request + with httpx.Client(timeout=30.0) as client: + if streaming: + # Handle streaming response + with client.stream( + "POST", + run_endpoint, + json=run_payload, + headers=headers, + ) as response: + if response.status_code == 200: + log.success("Workflow started successfully!") + log.separator() + log.step("Streaming response:") + + for line in response.iter_lines(): + if line.startswith("data: "): + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + log.success("Workflow completed!") + break + try: + data = json.loads(data_str) + event = data.get("event") + + if event == "workflow_started": + log.progress(f"Workflow started: 
{data.get('data', {}).get('id')}") + elif event == "node_started": + node_data = data.get("data", {}) + log.progress( + f"Node started: {node_data.get('node_type')} - {node_data.get('title')}" + ) + elif event == "node_finished": + node_data = data.get("data", {}) + log.progress( + f"Node finished: {node_data.get('node_type')} - {node_data.get('title')}" + ) + + # Print output if it's the LLM node + outputs = node_data.get("outputs", {}) + if outputs.get("text"): + log.separator() + log.info("💬 LLM Response:") + log.info(outputs.get("text"), indent=2) + log.separator() + + elif event == "workflow_finished": + workflow_data = data.get("data", {}) + outputs = workflow_data.get("outputs", {}) + if outputs.get("answer"): + log.separator() + log.info("📤 Final Answer:") + log.info(outputs.get("answer"), indent=2) + log.separator() + log.key_value( + "Total tokens", + str(workflow_data.get("total_tokens", 0)), + ) + log.key_value( + "Total steps", + str(workflow_data.get("total_steps", 0)), + ) + + elif event == "error": + log.error(f"Error: {data.get('message')}") + + except json.JSONDecodeError: + # Some lines might not be JSON + pass + else: + log.error(f"Workflow run failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + else: + # Handle blocking response + response = client.post( + run_endpoint, + json=run_payload, + headers=headers, + ) + + if response.status_code == 200: + log.success("Workflow completed successfully!") + response_data = response.json() + + log.separator() + log.debug(f"Full response: {json.dumps(response_data, indent=2)}") + + # Extract the answer if available + outputs = response_data.get("data", {}).get("outputs", {}) + if outputs.get("answer"): + log.separator() + log.info("📤 Final Answer:") + log.info(outputs.get("answer"), indent=2) + else: + log.error(f"Workflow run failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except httpx.TimeoutException: + log.error("Request timed out") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + # Allow passing question as command line argument + if len(sys.argv) > 1: + question = " ".join(sys.argv[1:]) + else: + question = "What is the capital of France?" 
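+    # Streaming mode exercises the full event sequence handled above
+    # (workflow_started, node_started/finished, workflow_finished).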
+ + run_workflow(question=question, streaming=True) diff --git a/scripts/stress-test/setup/setup_admin.py b/scripts/stress-test/setup/setup_admin.py new file mode 100755 index 0000000000..a5e9161210 --- /dev/null +++ b/scripts/stress-test/setup/setup_admin.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +from common import Logger, config_helper + + +def setup_admin_account() -> None: + """Setup Dify API with an admin account.""" + + log = Logger("SetupAdmin") + log.header("Setting up Admin Account") + + # Admin account credentials + admin_config = { + "email": "test@dify.ai", + "username": "dify", + "password": "password123", + } + + # Save credentials to config file + if config_helper.write_config("admin_config", admin_config): + log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}") + + # API setup endpoint + base_url = "http://localhost:5001" + setup_endpoint = f"{base_url}/console/api/setup" + + # Prepare setup payload + setup_payload = { + "email": admin_config["email"], + "name": admin_config["username"], + "password": admin_config["password"], + } + + log.step("Configuring Dify with admin account...") + + try: + # Make the setup request + with httpx.Client() as client: + response = client.post( + setup_endpoint, + json=setup_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 201: + log.success("Admin account created successfully!") + log.key_value("Email", admin_config["email"]) + log.key_value("Username", admin_config["username"]) + + elif response.status_code == 400: + log.warning("Setup may have already been completed or invalid data provided") + log.debug(f"Response: {response.text}") + else: + log.error(f"Setup failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + setup_admin_account() diff --git a/scripts/stress-test/setup_all.py b/scripts/stress-test/setup_all.py new file mode 100755 index 0000000000..ece420f925 --- /dev/null +++ b/scripts/stress-test/setup_all.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 + +import socket +import subprocess +import sys +import time +from pathlib import Path + +from common import Logger, ProgressLogger + + +def run_script(script_name: str, description: str) -> bool: + """Run a Python script and return success status.""" + script_path = Path(__file__).parent / "setup" / script_name + + if not script_path.exists(): + print(f"❌ Script not found: {script_path}") + return False + + print(f"\n{'=' * 60}") + print(f"🚀 {description}") + print(f" Running: {script_name}") + print(f"{'=' * 60}") + + try: + result = subprocess.run( + [sys.executable, str(script_path)], + capture_output=True, + text=True, + check=False, + ) + + # Print output + if result.stdout: + print(result.stdout) + if result.stderr: + print(result.stderr, file=sys.stderr) + + if result.returncode != 0: + print(f"❌ Script failed with exit code: {result.returncode}") + return False + + print(f"✅ {script_name} completed successfully") + return True + + except Exception as e: + print(f"❌ Error running {script_name}: {e}") + return False + + +def check_port(host: str, port: int, service_name: str) -> bool: 
+ """Check if a service is running on the specified port.""" + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + result = sock.connect_ex((host, port)) + sock.close() + + if result == 0: + Logger().success(f"{service_name} is running on port {port}") + return True + else: + Logger().error(f"{service_name} is not accessible on port {port}") + return False + except Exception as e: + Logger().error(f"Error checking {service_name}: {e}") + return False + + +def main() -> None: + """Run all setup scripts in order.""" + + log = Logger("Setup") + log.box("Dify Stress Test Setup - Full Installation") + + # Check if required services are running + log.step("Checking required services...") + log.separator() + + dify_running = check_port("localhost", 5001, "Dify API server") + if not dify_running: + log.info("To start Dify API server:") + log.list_item("Run: ./dev/start-api") + + mock_running = check_port("localhost", 5004, "Mock OpenAI server") + if not mock_running: + log.info("To start Mock OpenAI server:") + log.list_item("Run: python scripts/stress-test/setup/mock_openai_server.py") + + if not dify_running or not mock_running: + print("\n⚠️ Both services must be running before proceeding.") + retry = input("\nWould you like to check again? (yes/no): ") + if retry.lower() in ["yes", "y"]: + return main() # Recursively call main to check again + else: + print("❌ Setup cancelled. Please start the required services and try again.") + sys.exit(1) + + log.success("All required services are running!") + input("\nPress Enter to continue with setup...") + + # Define setup steps + setup_steps = [ + ("setup_admin.py", "Creating admin account"), + ("login_admin.py", "Logging in and getting access token"), + ("install_openai_plugin.py", "Installing OpenAI plugin"), + ("configure_openai_plugin.py", "Configuring OpenAI plugin with mock server"), + ("import_workflow_app.py", "Importing workflow application"), + ("create_api_key.py", "Creating API key for the app"), + ("publish_workflow.py", "Publishing the workflow"), + ] + + # Create progress logger + progress = ProgressLogger(len(setup_steps), log) + failed_step = None + + for script, description in setup_steps: + progress.next_step(description) + success = run_script(script, description) + + if not success: + failed_step = script + break + + # Small delay between steps + time.sleep(1) + + log.separator() + + if failed_step: + log.error(f"Setup failed at: {failed_step}") + log.separator() + log.info("Troubleshooting:") + log.list_item("Check if the Dify API server is running (./dev/start-api)") + log.list_item("Check if the mock OpenAI server is running (port 5004)") + log.list_item("Review the error messages above") + log.list_item("Run cleanup.py and try again") + sys.exit(1) + else: + progress.complete() + log.separator() + log.success("Setup completed successfully!") + log.info("Next steps:") + log.list_item("Test the workflow:") + log.info( + ' python scripts/stress-test/setup/run_workflow.py "Your question here"', + indent=4, + ) + log.list_item("To clean up and start over:") + log.info(" python scripts/stress-test/cleanup.py", indent=4) + + # Optionally run a test + log.separator() + test_input = input("Would you like to run a test workflow now? 
(yes/no): ") + + if test_input.lower() in ["yes", "y"]: + log.step("Running test workflow...") + run_script("run_workflow.py", "Testing workflow with default question") + + +if __name__ == "__main__": + main() diff --git a/scripts/stress-test/sse_benchmark.py b/scripts/stress-test/sse_benchmark.py new file mode 100644 index 0000000000..99fe2b20f4 --- /dev/null +++ b/scripts/stress-test/sse_benchmark.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +""" +SSE (Server-Sent Events) Stress Test for Dify Workflow API + +This script stress tests the streaming performance of Dify's workflow execution API, +measuring key metrics like connection rate, event throughput, and time to first event (TTFE). +""" + +import json +import logging +import os +import random +import statistics +import sys +import threading +import time +from collections import deque +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Literal, TypeAlias, TypedDict + +import requests.exceptions +from locust import HttpUser, between, constant, events, task + +# Add the stress-test directory to path to import common modules +sys.path.insert(0, str(Path(__file__).parent)) +from common.config_helper import ConfigHelper # type: ignore[import-not-found] + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# Configuration from environment +WORKFLOW_PATH = os.getenv("WORKFLOW_PATH", "/v1/workflows/run") +CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", "10")) +READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", "60")) +TERMINAL_EVENTS = [e.strip() for e in os.getenv("TERMINAL_EVENTS", "workflow_finished,error").split(",") if e.strip()] +QUESTIONS_FILE = os.getenv("QUESTIONS_FILE", "") + + +# Type definitions +ErrorType: TypeAlias = Literal[ + "connection_error", + "timeout", + "invalid_json", + "http_4xx", + "http_5xx", + "early_termination", + "invalid_response", +] + + +class ErrorCounts(TypedDict): + """Error count tracking""" + + connection_error: int + timeout: int + invalid_json: int + http_4xx: int + http_5xx: int + early_termination: int + invalid_response: int + + +class SSEEvent(TypedDict): + """Server-Sent Event structure""" + + data: str + event: str + id: str | None + + +class WorkflowInputs(TypedDict): + """Workflow input structure""" + + question: str + + +class WorkflowRequestData(TypedDict): + """Workflow request payload""" + + inputs: WorkflowInputs + response_mode: Literal["streaming"] + user: str + + +class ParsedEventData(TypedDict, total=False): + """Parsed event data from SSE stream""" + + event: str + task_id: str + workflow_run_id: str + data: object # For dynamic content + created_at: int + + +class LocustStats(TypedDict): + """Locust statistics structure""" + + total_requests: int + total_failures: int + avg_response_time: float + min_response_time: float + max_response_time: float + + +class ReportData(TypedDict): + """JSON report structure""" + + timestamp: str + duration_seconds: float + metrics: dict[str, object] # Metrics as dict for JSON serialization + locust_stats: LocustStats | None + + +@dataclass +class StreamMetrics: + """Metrics for a single stream""" + + stream_duration: float + events_count: int + bytes_received: int + ttfe: float + inter_event_times: list[float] + + +@dataclass +class MetricsSnapshot: + """Snapshot of current metrics state""" + + active_connections: int + total_connections: int + total_events: int + 
connection_rate: float + event_rate: float + overall_conn_rate: float + overall_event_rate: float + ttfe_avg: float + ttfe_min: float + ttfe_max: float + ttfe_p50: float + ttfe_p95: float + ttfe_samples: int + ttfe_total_samples: int # Total TTFE samples collected (not limited by window) + error_counts: ErrorCounts + stream_duration_avg: float + stream_duration_p50: float + stream_duration_p95: float + events_per_stream_avg: float + inter_event_latency_avg: float + inter_event_latency_p50: float + inter_event_latency_p95: float + + +class MetricsTracker: + def __init__(self) -> None: + self.lock = threading.Lock() + self.active_connections = 0 + self.total_connections = 0 + self.total_events = 0 + self.start_time = time.time() + + # Enhanced metrics with memory limits + self.max_samples = 10000 # Prevent unbounded growth + self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples) + self.ttfe_total_count = 0 # Track total TTFE samples collected + + # For rate calculations - no maxlen to avoid artificial limits + self.connection_times: deque[float] = deque() + self.event_times: deque[float] = deque() + self.last_stats_time = time.time() + self.last_total_connections = 0 + self.last_total_events = 0 + self.stream_metrics: deque[StreamMetrics] = deque(maxlen=self.max_samples) + self.error_counts: ErrorCounts = ErrorCounts( + connection_error=0, + timeout=0, + invalid_json=0, + http_4xx=0, + http_5xx=0, + early_termination=0, + invalid_response=0, + ) + + def connection_started(self) -> None: + with self.lock: + self.active_connections += 1 + self.total_connections += 1 + self.connection_times.append(time.time()) + + def connection_ended(self) -> None: + with self.lock: + self.active_connections -= 1 + + def event_received(self) -> None: + with self.lock: + self.total_events += 1 + self.event_times.append(time.time()) + + def record_ttfe(self, ttfe_ms: float) -> None: + with self.lock: + self.ttfe_samples.append(ttfe_ms) # deque handles maxlen + self.ttfe_total_count += 1 # Increment total counter + + def record_stream_metrics(self, metrics: StreamMetrics) -> None: + with self.lock: + self.stream_metrics.append(metrics) # deque handles maxlen + + def record_error(self, error_type: ErrorType) -> None: + with self.lock: + self.error_counts[error_type] += 1 + + def get_stats(self) -> MetricsSnapshot: + with self.lock: + current_time = time.time() + time_window = 10.0 # 10 second window for rate calculation + + # Clean up old timestamps outside the window + cutoff_time = current_time - time_window + while self.connection_times and self.connection_times[0] < cutoff_time: + self.connection_times.popleft() + while self.event_times and self.event_times[0] < cutoff_time: + self.event_times.popleft() + + # Calculate rates based on actual window or elapsed time + window_duration = min(time_window, current_time - self.start_time) + if window_duration > 0: + conn_rate = len(self.connection_times) / window_duration + event_rate = len(self.event_times) / window_duration + else: + conn_rate = 0 + event_rate = 0 + + # Calculate TTFE statistics + if self.ttfe_samples: + avg_ttfe = statistics.mean(self.ttfe_samples) + min_ttfe = min(self.ttfe_samples) + max_ttfe = max(self.ttfe_samples) + p50_ttfe = statistics.median(self.ttfe_samples) + if len(self.ttfe_samples) >= 2: + quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive") + p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile + else: + p95_ttfe = max_ttfe + else: + avg_ttfe = min_ttfe = max_ttfe = p50_ttfe = p95_ttfe 
= 0 + + # Calculate stream metrics + if self.stream_metrics: + durations = [m.stream_duration for m in self.stream_metrics] + events_per_stream = [m.events_count for m in self.stream_metrics] + stream_duration_avg = statistics.mean(durations) + stream_duration_p50 = statistics.median(durations) + stream_duration_p95 = ( + statistics.quantiles(durations, n=20, method="inclusive")[18] + if len(durations) >= 2 + else max(durations) + if durations + else 0 + ) + events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0 + + # Calculate inter-event latency statistics + all_inter_event_times = [] + for m in self.stream_metrics: + all_inter_event_times.extend(m.inter_event_times) + + if all_inter_event_times: + inter_event_latency_avg = statistics.mean(all_inter_event_times) + inter_event_latency_p50 = statistics.median(all_inter_event_times) + inter_event_latency_p95 = ( + statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18] + if len(all_inter_event_times) >= 2 + else max(all_inter_event_times) + ) + else: + inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0 + else: + stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0 + inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0 + + # Also calculate overall average rates + total_elapsed = current_time - self.start_time + overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0 + overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0 + + return MetricsSnapshot( + active_connections=self.active_connections, + total_connections=self.total_connections, + total_events=self.total_events, + connection_rate=conn_rate, + event_rate=event_rate, + overall_conn_rate=overall_conn_rate, + overall_event_rate=overall_event_rate, + ttfe_avg=avg_ttfe, + ttfe_min=min_ttfe, + ttfe_max=max_ttfe, + ttfe_p50=p50_ttfe, + ttfe_p95=p95_ttfe, + ttfe_samples=len(self.ttfe_samples), + ttfe_total_samples=self.ttfe_total_count, # Return total count + error_counts=ErrorCounts(**self.error_counts), + stream_duration_avg=stream_duration_avg, + stream_duration_p50=stream_duration_p50, + stream_duration_p95=stream_duration_p95, + events_per_stream_avg=events_per_stream_avg, + inter_event_latency_avg=inter_event_latency_avg, + inter_event_latency_p50=inter_event_latency_p50, + inter_event_latency_p95=inter_event_latency_p95, + ) + + +# Global metrics instance +metrics = MetricsTracker() + + +class SSEParser: + """Parser for Server-Sent Events according to W3C spec""" + + def __init__(self) -> None: + self.data_buffer: list[str] = [] + self.event_type: str | None = None + self.event_id: str | None = None + + def parse_line(self, line: str) -> SSEEvent | None: + """Parse a single SSE line and return event if complete""" + # Empty line signals end of event + if not line: + if self.data_buffer: + event = SSEEvent( + data="\n".join(self.data_buffer), + event=self.event_type or "message", + id=self.event_id, + ) + self.data_buffer = [] + self.event_type = None + self.event_id = None + return event + return None + + # Comment line + if line.startswith(":"): + return None + + # Parse field + if ":" in line: + field, value = line.split(":", 1) + value = value.lstrip() + + if field == "data": + self.data_buffer.append(value) + elif field == "event": + self.event_type = value + elif field == "id": + self.event_id = value + + return None + + +# Note: SSEClient removed - we'll handle SSE parsing 
directly in the task for better Locust integration + + +class DifyWorkflowUser(HttpUser): + """Locust user for testing Dify workflow SSE endpoints""" + + # Use constant wait for streaming workloads + wait_time = constant(0) if os.getenv("WAIT_TIME", "0") == "0" else between(1, 3) + + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + + # Load API configuration + config_helper = ConfigHelper() + self.api_token = config_helper.get_api_key() + + if not self.api_token: + raise ValueError("API key not found. Please run setup_all.py first.") + + # Load questions from file or use defaults + if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE): + with open(QUESTIONS_FILE) as f: + self.questions = [line.strip() for line in f if line.strip()] + else: + self.questions = [ + "What is artificial intelligence?", + "Explain quantum computing", + "What is machine learning?", + "How do neural networks work?", + "What is renewable energy?", + ] + + self.user_counter = 0 + + def on_start(self) -> None: + """Called when a user starts""" + self.user_counter = 0 + + @task + def test_workflow_stream(self) -> None: + """Test workflow SSE streaming endpoint""" + + question = random.choice(self.questions) + self.user_counter += 1 + + headers = { + "Authorization": f"Bearer {self.api_token}", + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + data = WorkflowRequestData( + inputs=WorkflowInputs(question=question), + response_mode="streaming", + user=f"user_{self.user_counter}", + ) + + start_time = time.time() + first_event_time = None + event_count = 0 + inter_event_times: list[float] = [] + last_event_time = None + ttfe = 0 + request_success = False + bytes_received = 0 + + metrics.connection_started() + + # Use catch_response context manager directly + with self.client.request( + method="POST", + url=WORKFLOW_PATH, + headers=headers, + json=data, + stream=True, + catch_response=True, + timeout=(CONNECT_TIMEOUT, READ_TIMEOUT), + name="/v1/workflows/run", # Name for Locust stats + ) as response: + try: + # Validate response + if response.status_code >= 400: + error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx" + metrics.record_error(error_type) + response.failure(f"HTTP {response.status_code}") + return + + content_type = response.headers.get("Content-Type", "") + if "text/event-stream" not in content_type and "application/json" not in content_type: + logger.error(f"Expected text/event-stream, got: {content_type}") + metrics.record_error("invalid_response") + response.failure(f"Invalid content type: {content_type}") + return + + # Parse SSE events + parser = SSEParser() + + for line in response.iter_lines(decode_unicode=True): + # Check if runner is stopping + if getattr(self.environment.runner, "state", "") in ( + "stopping", + "stopped", + ): + logger.debug("Runner stopping, breaking streaming loop") + break + + if line is not None: + bytes_received += len(line.encode("utf-8")) + + # Parse SSE line + event = parser.parse_line(line if line is not None else "") + if event: + event_count += 1 + current_time = time.time() + metrics.event_received() + + # Track inter-event timing + if last_event_time: + inter_event_times.append((current_time - last_event_time) * 1000) + last_event_time = current_time + + if first_event_time is None: + first_event_time = current_time + ttfe = (first_event_time - start_time) * 1000 + metrics.record_ttfe(ttfe) + + try: + # Parse 
event data + event_data = event.get("data", "") + if event_data: + if event_data == "[DONE]": + logger.debug("Received [DONE] sentinel") + request_success = True + break + + try: + parsed_event: ParsedEventData = json.loads(event_data) + # Check for terminal events + if parsed_event.get("event") in TERMINAL_EVENTS: + logger.debug(f"Received terminal event: {parsed_event.get('event')}") + request_success = True + break + except json.JSONDecodeError as e: + logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}") + metrics.record_error("invalid_json") + + except Exception as e: + logger.error(f"Error processing event: {e}") + + # Mark success only if terminal condition was met or events were received + if request_success: + response.success() + elif event_count > 0: + # Got events but no proper terminal condition + metrics.record_error("early_termination") + response.failure("Stream ended without terminal event") + else: + response.failure("No events received") + + except ( + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ) as e: + metrics.record_error("timeout") + response.failure(f"Timeout: {e}") + except ( + requests.exceptions.ConnectionError, + requests.exceptions.RequestException, + ) as e: + metrics.record_error("connection_error") + response.failure(f"Connection error: {e}") + except Exception as e: + response.failure(str(e)) + raise + finally: + metrics.connection_ended() + + # Record stream metrics + if event_count > 0: + stream_duration = (time.time() - start_time) * 1000 + stream_metrics = StreamMetrics( + stream_duration=stream_duration, + events_count=event_count, + bytes_received=bytes_received, + ttfe=ttfe, + inter_event_times=inter_event_times, + ) + metrics.record_stream_metrics(stream_metrics) + logger.debug( + f"Stream completed: {event_count} events, {stream_duration:.1f}ms, success={request_success}" + ) + else: + logger.warning("No events received in stream") + + +# Event handlers +@events.test_start.add_listener # type: ignore[misc] +def on_test_start(environment: object, **kwargs: object) -> None: + logger.info("=" * 80) + logger.info(" " * 25 + "DIFY SSE BENCHMARK - REAL-TIME METRICS") + logger.info("=" * 80) + logger.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info("=" * 80) + + # Periodic stats reporting + def report_stats() -> None: + if not hasattr(environment, "runner"): + return + runner = environment.runner + while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]: + time.sleep(5) # Report every 5 seconds + if hasattr(runner, "state") and runner.state == "running": + stats = metrics.get_stats() + + # Only log on master node in distributed mode + is_master = ( + not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True + ) + if is_master: + # Clear previous lines and show updated stats + logger.info("\n" + "=" * 80) + logger.info( + f"{'METRIC':<25} {'CURRENT':>15} {'RATE (10s)':>15} {'AVG (overall)':>15} {'TOTAL':>12}" + ) + logger.info("-" * 80) + + # Active SSE Connections + logger.info( + f"{'Active SSE Connections':<25} {stats.active_connections:>15,d} {'-':>15} {'-':>12} {'-':>12}" + ) + + # New Connection Rate + logger.info( + f"{'New Connections':<25} {'-':>15} {stats.connection_rate:>13.2f}/s {stats.overall_conn_rate:>13.2f}/s {stats.total_connections:>12,d}" + ) + + # Event Throughput + logger.info( + f"{'Event Throughput':<25} {'-':>15} {stats.event_rate:>13.2f}/s {stats.overall_event_rate:>13.2f}/s 
{stats.total_events:>12,d}" + ) + + logger.info("-" * 80) + logger.info( + f"{'TIME TO FIRST EVENT':<25} {'AVG':>15} {'P50':>10} {'P95':>10} {'MIN':>10} {'MAX':>10}" + ) + logger.info( + f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}" + ) + logger.info( + f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)" + ) + logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}") + + # Inter-event latency + if stats.inter_event_latency_avg > 0: + logger.info("-" * 80) + logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}") + logger.info( + f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}" + ) + + # Error stats + if any(stats.error_counts.values()): + logger.info("-" * 80) + logger.info(f"{'ERROR TYPE':<25} {'COUNT':>15}") + for error_type, count in stats.error_counts.items(): + if isinstance(count, int) and count > 0: + logger.info(f"{error_type:<25} {count:>15,d}") + + logger.info("=" * 80) + + # Show Locust stats summary + if hasattr(environment, "stats") and hasattr(environment.stats, "total"): + total = environment.stats.total + if hasattr(total, "num_requests") and total.num_requests > 0: + logger.info( + f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}" + ) + logger.info("-" * 80) + logger.info( + f"{'Aggregated':<25} {total.num_requests:>12,d} " + f"{total.num_failures:>8,d} " + f"{total.avg_response_time:>12.1f} " + f"{total.min_response_time:>8.0f} " + f"{total.max_response_time:>8.0f}" + ) + logger.info("=" * 80) + + threading.Thread(target=report_stats, daemon=True).start() + + +@events.test_stop.add_listener # type: ignore[misc] +def on_test_stop(environment: object, **kwargs: object) -> None: + stats = metrics.get_stats() + test_duration = time.time() - metrics.start_time + + # Log final results + logger.info("\n" + "=" * 80) + logger.info(" " * 30 + "FINAL BENCHMARK RESULTS") + logger.info("=" * 80) + logger.info(f"Test Duration: {test_duration:.1f} seconds") + logger.info("-" * 80) + + logger.info("") + logger.info("CONNECTIONS") + logger.info(f" {'Total Connections:':<30} {stats.total_connections:>10,d}") + logger.info(f" {'Final Active:':<30} {stats.active_connections:>10,d}") + logger.info(f" {'Average Rate:':<30} {stats.overall_conn_rate:>10.2f} conn/s") + + logger.info("") + logger.info("EVENTS") + logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}") + logger.info(f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s") + logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s") + + logger.info("") + logger.info("STREAM METRICS") + logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms") + logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms") + logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms") + logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}") + + logger.info("") + logger.info("INTER-EVENT LATENCY") + logger.info(f" {'Average:':<30} {stats.inter_event_latency_avg:>10.1f} ms") + logger.info(f" {'Median (P50):':<30} {stats.inter_event_latency_p50:>10.1f} ms") + logger.info(f" {'95th Percentile:':<30} {stats.inter_event_latency_p95:>10.1f} ms") + + logger.info("") 
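+    # TTFE stats are computed over a bounded sample window (max_samples);
+    # the total-sample counter is not window-limited.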
+ logger.info("TIME TO FIRST EVENT (ms)") + logger.info(f" {'Average:':<30} {stats.ttfe_avg:>10.1f} ms") + logger.info(f" {'Median (P50):':<30} {stats.ttfe_p50:>10.1f} ms") + logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms") + logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms") + logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms") + logger.info( + f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})" + ) + logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}") + + # Error summary + if any(stats.error_counts.values()): + logger.info("") + logger.info("ERRORS") + for error_type, count in stats.error_counts.items(): + if isinstance(count, int) and count > 0: + logger.info(f" {error_type:<30} {count:>10,d}") + + logger.info("=" * 80 + "\n") + + # Export machine-readable report (only on master node) + is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True + if is_master: + export_json_report(stats, test_duration, environment) + + +def export_json_report(stats: MetricsSnapshot, duration: float, environment: object) -> None: + """Export metrics to JSON file for CI/CD analysis""" + + reports_dir = Path(__file__).parent / "reports" + reports_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + report_file = reports_dir / f"sse_metrics_{timestamp}.json" + + # Access environment.stats.total attributes safely + locust_stats: LocustStats | None = None + if hasattr(environment, "stats") and hasattr(environment.stats, "total"): + total = environment.stats.total + if hasattr(total, "num_requests") and total.num_requests > 0: + locust_stats = LocustStats( + total_requests=total.num_requests, + total_failures=total.num_failures, + avg_response_time=total.avg_response_time, + min_response_time=total.min_response_time, + max_response_time=total.max_response_time, + ) + + report_data = ReportData( + timestamp=datetime.now().isoformat(), + duration_seconds=duration, + metrics=asdict(stats), # type: ignore[arg-type] + locust_stats=locust_stats, + ) + + with open(report_file, "w") as f: + json.dump(report_data, f, indent=2) + + logger.info(f"Exported metrics to {report_file}") diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index 0ba7bba8bb..3025cc2ab6 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -95,10 +95,9 @@ export class DifyClient { headerParams = {} ) { const headers = { - ...{ + Authorization: `Bearer ${this.apiKey}`, "Content-Type": "application/json", - }, ...headerParams }; diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index d00c207afa..e866472f45 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -1,7 +1,15 @@ from dify_client.client import ( ChatClient, CompletionClient, - WorkflowClient, - KnowledgeBaseClient, DifyClient, + KnowledgeBaseClient, + WorkflowClient, ) + +__all__ = [ + "ChatClient", + "CompletionClient", + "DifyClient", + "KnowledgeBaseClient", + "WorkflowClient", +] diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index abd0e7ae29..791cb98a1b 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,5 +1,5 @@ import json - +from typing import Literal import requests @@ -8,16 +8,16 @@ class DifyClient: self.api_key = api_key self.base_url = 
base_url - def _send_request(self, method, endpoint, json=None, params=None, stream=False): + def _send_request( + self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False + ): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } url = f"{self.base_url}{endpoint}" - response = requests.request( - method, url, json=json, params=params, headers=headers, stream=stream - ) + response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream) return response @@ -25,37 +25,35 @@ class DifyClient: headers = {"Authorization": f"Bearer {self.api_key}"} url = f"{self.base_url}{endpoint}" - response = requests.request( - method, url, data=data, headers=headers, files=files - ) + response = requests.request(method, url, data=data, headers=headers, files=files) return response - def message_feedback(self, message_id, rating, user): + def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): data = {"rating": rating, "user": user} return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - def get_application_parameters(self, user): + def get_application_parameters(self, user: str): params = {"user": user} return self._send_request("GET", "/parameters", params=params) - def file_upload(self, user, files): + def file_upload(self, user: str, files: dict): data = {"user": user} - return self._send_request_with_files( - "POST", "/files/upload", data=data, files=files - ) + return self._send_request_with_files("POST", "/files/upload", data=data, files=files) def text_to_audio(self, text: str, user: str, streaming: bool = False): data = {"text": text, "user": user, "streaming": streaming} return self._send_request("POST", "/text-to-audio", json=data) - def get_meta(self, user): + def get_meta(self, user: str): params = {"user": user} return self._send_request("GET", "/meta", params=params) class CompletionClient(DifyClient): - def create_completion_message(self, inputs, response_mode, user, files=None): + def create_completion_message( + self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None + ): data = { "inputs": inputs, "response_mode": response_mode, @@ -76,7 +74,7 @@ class ChatClient(DifyClient): inputs: dict, query: str, user: str, - response_mode: str = "blocking", + response_mode: Literal["blocking", "streaming"] = "blocking", conversation_id: str | None = None, files: dict | None = None, ): @@ -99,9 +97,7 @@ class ChatClient(DifyClient): def get_suggested(self, message_id: str, user: str): params = {"user": user} - return self._send_request( - "GET", f"/messages/{message_id}/suggested", params=params - ) + return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) def stop_message(self, task_id: str, user: str): data = {"user": user} @@ -112,10 +108,9 @@ class ChatClient(DifyClient): user: str, last_id: str | None = None, limit: int | None = None, - pinned: bool | None = None + pinned: bool | None = None, ): - params = {"user": user, "last_id": last_id, - "limit": limit, "pinned": pinned} + params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} return self._send_request("GET", "/conversations", params=params) def get_conversation_messages( @@ -123,7 +118,7 @@ class ChatClient(DifyClient): user: str, conversation_id: str | None = None, first_id: str | None = None, - limit: int | None = None + limit: int | None = None, ): params = 
{"user": user} @@ -136,13 +131,9 @@ class ChatClient(DifyClient): return self._send_request("GET", "/messages", params=params) - def rename_conversation( - self, conversation_id: str, name: str, auto_generate: bool, user: str - ): + def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request( - "POST", f"/conversations/{conversation_id}/name", data - ) + return self._send_request("POST", f"/conversations/{conversation_id}/name", data) def delete_conversation(self, conversation_id: str, user: str): data = {"user": user} @@ -155,9 +146,7 @@ class ChatClient(DifyClient): class WorkflowClient(DifyClient): - def run( - self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123" - ): + def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"): data = {"inputs": inputs, "response_mode": response_mode, "user": user} return self._send_request("POST", "/workflows/run", data) @@ -172,7 +161,7 @@ class WorkflowClient(DifyClient): class KnowledgeBaseClient(DifyClient): def __init__( self, - api_key, + api_key: str, base_url: str = "https://api.dify.ai/v1", dataset_id: str | None = None, ): @@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient): return self._send_request("POST", "/datasets", {"name": name}, **kwargs) def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request( - "GET", f"/datasets?page={page}&limit={page_size}", **kwargs - ) + return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs) - def create_document_by_text( - self, name, text, extra_params: dict | None = None, **kwargs - ): + def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs): """ Create a document by text. @@ -241,7 +226,7 @@ class KnowledgeBaseClient(DifyClient): return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict | None = None, **kwargs + self, document_id: str, name: str, text: str, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -272,13 +257,11 @@ class KnowledgeBaseClient(DifyClient): data = {"name": name, "text": text} if extra_params is not None and isinstance(extra_params, dict): data.update(extra_params) - url = ( - f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - ) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict | None = None + self, file_path: str, original_document_id: str | None = None, extra_params: dict | None = None ): """ Create a document by file. 
@@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient): if original_document_id is not None: data["original_document_id"] = original_document_id url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files( - "POST", url, {"data": json.dumps(data)}, files - ) + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - def update_document_by_file( - self, document_id, file_path, extra_params: dict | None = None - ): + def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None): """ Update a document by file. @@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient): data = {} if extra_params is not None and isinstance(extra_params, dict): data.update(extra_params) - url = ( - f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - ) - return self._send_request_with_files( - "POST", url, {"data": json.dumps(data)}, files - ) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) def batch_indexing_status(self, batch_id: str, **kwargs): """ @@ -377,7 +352,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}" return self._send_request("DELETE", url) - def delete_document(self, document_id): + def delete_document(self, document_id: str): """ Delete a document. @@ -409,7 +384,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}/documents" return self._send_request("GET", url, params=params, **kwargs) - def add_segments(self, document_id, segments, **kwargs): + def add_segments(self, document_id: str, segments: list[dict], **kwargs): """ Add segments to a document. @@ -423,7 +398,7 @@ class KnowledgeBaseClient(DifyClient): def query_segments( self, - document_id, + document_id: str, keyword: str | None = None, status: str | None = None, **kwargs, @@ -445,7 +420,7 @@ class KnowledgeBaseClient(DifyClient): params.update(kwargs["params"]) return self._send_request("GET", url, params=params, **kwargs) - def delete_document_segment(self, document_id, segment_id): + def delete_document_segment(self, document_id: str, segment_id: str): """ Delete a segment from a document. @@ -456,7 +431,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" return self._send_request("DELETE", url) - def update_document_segment(self, document_id, segment_id, segment_data, **kwargs): + def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): """ Update a segment in a document. 
diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index 7340fffb4c..a05f6410fb 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -with open("README.md", "r", encoding="utf-8") as fh: +with open("README.md", encoding="utf-8") as fh: long_description = fh.read() setup( diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index 52032417c0..fce1b11eba 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__) class TestKnowledgeBaseClient(unittest.TestCase): def setUp(self): self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) - self.README_FILE_PATH = os.path.abspath( - os.path.join(FILE_PATH_BASE, "../README.md") - ) + self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) self.dataset_id = None self.document_id = None self.segment_id = None @@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _get_dataset_kb_client(self): self.assertIsNotNone(self.dataset_id) - return KnowledgeBaseClient( - API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id - ) + return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id) def test_001_create_dataset(self): response = self.knowledge_base_client.create_dataset(name="test_dataset") @@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_004_update_document_by_text(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) - response = client.update_document_by_text( - self.document_id, "test_document_updated", "test_text_updated" - ) + response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") data = response.json() self.assertIn("document", data) self.assertIn("batch", data) @@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_006_update_document_by_file(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) - response = client.update_document_by_file( - self.document_id, self.README_FILE_PATH - ) + response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) data = response.json() self.assertIn("document", data) self.assertIn("batch", data) @@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_010_add_segments(self): client = self._get_dataset_kb_client() - response = client.add_segments( - self.document_id, [{"content": "test text segment 1"}] - ) + response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) data = response.json() self.assertIn("data", data) self.assertGreater(len(data["data"]), 0) @@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase): self.chat_client = ChatClient(API_KEY) def test_create_chat_message(self): - response = self.chat_client.create_chat_message( - {}, "Hello, World!", "test_user" - ) + response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user") self.assertIn("answer", response.text) def test_create_chat_message_with_vision_model_by_remote_url(self): - files = [ - {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} - ] - response = self.chat_client.create_chat_message( - {}, "Describe the picture.", "test_user", files=files - ) + files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] 
+ response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) self.assertIn("answer", response.text) def test_create_chat_message_with_vision_model_by_local_file(self): files = [ { "type": "image", "transfer_method": "local_file", "upload_file_id": "your_file_id", } ] - response = self.chat_client.create_chat_message( - {}, "Describe the picture.", "test_user", files=files - ) + response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) self.assertIn("answer", response.text) def test_get_conversation_messages(self): - response = self.chat_client.get_conversation_messages( - "test_user", "your_conversation_id" - ) + response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id") self.assertIn("answer", response.text) def test_get_conversations(self): @@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase): self.assertIn("answer", response.text) def test_create_completion_message_with_vision_model_by_remote_url(self): - files = [ - {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} - ] + files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] response = self.completion_client.create_completion_message( {"query": "Describe the picture."}, "blocking", "test_user", files ) @@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase): self.dify_client = DifyClient(API_KEY) def test_message_feedback(self): - response = self.dify_client.message_feedback( - "your_message_id", "like", "test_user" - ) + response = self.dify_client.message_feedback("your_message_id", "like", "test_user") self.assertIn("success", response.text) def test_get_application_parameters(self): diff --git a/web/.oxlintrc.json b/web/.oxlintrc.json index 1bfcca58f5..57eddd34fb 100644 --- a/web/.oxlintrc.json +++ b/web/.oxlintrc.json @@ -45,7 +45,7 @@ "no-unassigned-vars": "warn", "no-unsafe-finally": "warn", "no-unsafe-negation": "warn", - "no-unsafe-optional-chaining": "warn", + "no-unsafe-optional-chaining": "error", "no-unused-labels": "warn", "no-unused-private-class-members": "warn", "no-unused-vars": "warn", diff --git a/web/__tests__/goto-anything/match-action.test.ts b/web/__tests__/goto-anything/match-action.test.ts new file mode 100644 index 0000000000..3df9c0d533 --- /dev/null +++ b/web/__tests__/goto-anything/match-action.test.ts @@ -0,0 +1,235 @@ +import type { ActionItem } from '../../app/components/goto-anything/actions/types' + +// Mock the entire actions module to avoid import issues +jest.mock('../../app/components/goto-anything/actions', () => ({ + matchAction: jest.fn(), +})) + +jest.mock('../../app/components/goto-anything/actions/commands/registry') + +// Import after mocking to get mocked version +import { matchAction } from '../../app/components/goto-anything/actions' +import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' + +// Implement the actual matchAction logic for testing +const actualMatchAction = (query: string, actions: Record<string, ActionItem>) => { + const result = Object.values(actions).find((action) => { + // Special handling for slash commands + if (action.key === '/') { + // Get all registered commands from the registry + const allCommands = slashCommandRegistry.getAllCommands() + + // Check if query matches any registered command + return allCommands.some((cmd) => { + const cmdPattern = `/${cmd.name}` + + // For direct mode commands, don't match (keep in command selector) + if (cmd.mode 
=== 'direct') + return false + + // For submenu mode commands, match when complete command is entered + return query === cmdPattern || query.startsWith(`${cmdPattern} `) + }) + } + + const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`) + return reg.test(query) + }) + return result +} + +// Replace mock with actual implementation +;(matchAction as jest.Mock).mockImplementation(actualMatchAction) + +describe('matchAction Logic', () => { + const mockActions: Record<string, ActionItem> = { + app: { + key: '@app', + shortcut: '@a', + title: 'Search Applications', + description: 'Search apps', + search: jest.fn(), + }, + knowledge: { + key: '@knowledge', + shortcut: '@kb', + title: 'Search Knowledge', + description: 'Search knowledge bases', + search: jest.fn(), + }, + slash: { + key: '/', + shortcut: '/', + title: 'Commands', + description: 'Execute commands', + search: jest.fn(), + }, + } + + beforeEach(() => { + jest.clearAllMocks() + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'docs', mode: 'direct' }, + { name: 'community', mode: 'direct' }, + { name: 'feedback', mode: 'direct' }, + { name: 'account', mode: 'direct' }, + { name: 'theme', mode: 'submenu' }, + { name: 'language', mode: 'submenu' }, + ]) + }) + + describe('@ Actions Matching', () => { + it('should match @app with key', () => { + const result = matchAction('@app', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should match @app with shortcut', () => { + const result = matchAction('@a', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should match @knowledge with key', () => { + const result = matchAction('@knowledge', mockActions) + expect(result).toBe(mockActions.knowledge) + }) + + it('should match @knowledge with shortcut @kb', () => { + const result = matchAction('@kb', mockActions) + expect(result).toBe(mockActions.knowledge) + }) + + it('should match with text after action', () => { + const result = matchAction('@app search term', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should not match partial @ actions', () => { + const result = matchAction('@ap', mockActions) + expect(result).toBeUndefined() + }) + }) + + describe('Slash Commands Matching', () => { + describe('Direct Mode Commands', () => { + it('should not match direct mode commands', () => { + const result = matchAction('/docs', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match direct mode with arguments', () => { + const result = matchAction('/docs something', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match any direct mode command', () => { + expect(matchAction('/community', mockActions)).toBeUndefined() + expect(matchAction('/feedback', mockActions)).toBeUndefined() + expect(matchAction('/account', mockActions)).toBeUndefined() + }) + }) + + describe('Submenu Mode Commands', () => { + it('should match submenu mode commands exactly', () => { + const result = matchAction('/theme', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should match submenu mode with arguments', () => { + const result = matchAction('/theme dark', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should match all submenu commands', () => { + expect(matchAction('/language', mockActions)).toBe(mockActions.slash) + expect(matchAction('/language en', mockActions)).toBe(mockActions.slash) + }) + }) + + describe('Slash Without Command', () => { + it('should not match single slash', () => { + const result = 
matchAction('/', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match unregistered commands', () => { + const result = matchAction('/unknown', mockActions) + expect(result).toBeUndefined() + }) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty query', () => { + const result = matchAction('', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle whitespace only', () => { + const result = matchAction(' ', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle regular text without actions', () => { + const result = matchAction('search something', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle special characters', () => { + const result = matchAction('#tag', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle multiple @ or /', () => { + expect(matchAction('@@app', mockActions)).toBeUndefined() + expect(matchAction('//theme', mockActions)).toBeUndefined() + }) + }) + + describe('Mode-based Filtering', () => { + it('should filter direct mode commands from matching', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test', mode: 'direct' }, + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBeUndefined() + }) + + it('should allow submenu mode commands to match', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test', mode: 'submenu' }, + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should treat undefined mode as submenu', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test' }, // No mode specified + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBe(mockActions.slash) + }) + }) + + describe('Registry Integration', () => { + it('should call getAllCommands when matching slash', () => { + matchAction('/theme', mockActions) + expect(slashCommandRegistry.getAllCommands).toHaveBeenCalled() + }) + + it('should not call getAllCommands for @ actions', () => { + matchAction('@app', mockActions) + expect(slashCommandRegistry.getAllCommands).not.toHaveBeenCalled() + }) + + it('should handle empty command list', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([]) + const result = matchAction('/anything', mockActions) + expect(result).toBeUndefined() + }) + }) +}) diff --git a/web/__tests__/goto-anything/scope-command-tags.test.tsx b/web/__tests__/goto-anything/scope-command-tags.test.tsx new file mode 100644 index 0000000000..339e259a06 --- /dev/null +++ b/web/__tests__/goto-anything/scope-command-tags.test.tsx @@ -0,0 +1,134 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import '@testing-library/jest-dom' + +// Type alias for search mode +type SearchMode = 'scopes' | 'commands' | null + +// Mock component to test tag display logic +const TagDisplay: React.FC<{ searchMode: SearchMode }> = ({ searchMode }) => { + if (!searchMode) return null + + return ( +
<div className='flex items-center gap-1 text-xs text-text-tertiary'> + {searchMode === 'scopes' ? 'SCOPES' : 'COMMANDS'} + </div>
+ ) +} + +describe('Scope and Command Tags', () => { + describe('Tag Display Logic', () => { + it('should display SCOPES for @ actions', () => { + render(<TagDisplay searchMode='scopes' />) + expect(screen.getByText('SCOPES')).toBeInTheDocument() + expect(screen.queryByText('COMMANDS')).not.toBeInTheDocument() + }) + + it('should display COMMANDS for / actions', () => { + render(<TagDisplay searchMode='commands' />) + expect(screen.getByText('COMMANDS')).toBeInTheDocument() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + }) + + it('should not display any tag when searchMode is null', () => { + const { container } = render(<TagDisplay searchMode={null} />) + expect(container.firstChild).toBeNull() + }) + }) + + describe('Search Mode Detection', () => { + const getSearchMode = (query: string): SearchMode => { + if (query.startsWith('@')) return 'scopes' + if (query.startsWith('/')) return 'commands' + return null + } + + it('should detect scopes mode for @ queries', () => { + expect(getSearchMode('@app')).toBe('scopes') + expect(getSearchMode('@knowledge')).toBe('scopes') + expect(getSearchMode('@plugin')).toBe('scopes') + expect(getSearchMode('@node')).toBe('scopes') + }) + + it('should detect commands mode for / queries', () => { + expect(getSearchMode('/theme')).toBe('commands') + expect(getSearchMode('/language')).toBe('commands') + expect(getSearchMode('/docs')).toBe('commands') + }) + + it('should return null for regular queries', () => { + expect(getSearchMode('')).toBe(null) + expect(getSearchMode('search term')).toBe(null) + expect(getSearchMode('app')).toBe(null) + }) + + it('should handle queries with spaces', () => { + expect(getSearchMode('@app search')).toBe('scopes') + expect(getSearchMode('/theme dark')).toBe('commands') + }) + }) + + describe('Tag Styling', () => { + it('should apply correct styling classes', () => { + const { container } = render(<TagDisplay searchMode='scopes' />) + const tagContainer = container.querySelector('.flex.items-center.gap-1.text-xs.text-text-tertiary') + expect(tagContainer).toBeInTheDocument() + }) + + it('should use hardcoded English text', () => { + // Verify that tags are hardcoded and not using i18n + render(<TagDisplay searchMode='scopes' />) + const scopesText = screen.getByText('SCOPES') + expect(scopesText.textContent).toBe('SCOPES') + + render(<TagDisplay searchMode='commands' />) + const commandsText = screen.getByText('COMMANDS') + expect(commandsText.textContent).toBe('COMMANDS') + }) + }) + + describe('Integration with Search States', () => { + const SearchComponent: React.FC<{ query: string }> = ({ query }) => { + let searchMode: SearchMode = null + + if (query.startsWith('@')) searchMode = 'scopes' + else if (query.startsWith('/')) searchMode = 'commands' + + return ( +
<div> + <span>{query}</span> + <TagDisplay searchMode={searchMode} /> + </div>
+ ) + } + + it('should update tag when switching between @ and /', () => { + const { rerender } = render(<SearchComponent query='@app' />) + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender(<SearchComponent query='/theme' />) + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + expect(screen.getByText('COMMANDS')).toBeInTheDocument() + }) + + it('should hide tag when clearing search', () => { + const { rerender } = render(<SearchComponent query='@app' />) + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender(<SearchComponent query='' />) + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + expect(screen.queryByText('COMMANDS')).not.toBeInTheDocument() + }) + + it('should maintain correct tag during search refinement', () => { + const { rerender } = render(<SearchComponent query='@app' />) + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender(<SearchComponent query='@app search' />) + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender(<SearchComponent query='@app search term' />) + expect(screen.getByText('SCOPES')).toBeInTheDocument() + }) + }) +}) diff --git a/web/__tests__/goto-anything/slash-command-modes.test.tsx b/web/__tests__/goto-anything/slash-command-modes.test.tsx new file mode 100644 index 0000000000..f8126958fc --- /dev/null +++ b/web/__tests__/goto-anything/slash-command-modes.test.tsx @@ -0,0 +1,212 @@ +import '@testing-library/jest-dom' +import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' +import type { SlashCommandHandler } from '../../app/components/goto-anything/actions/commands/types' + +// Mock the registry +jest.mock('../../app/components/goto-anything/actions/commands/registry') + +describe('Slash Command Dual-Mode System', () => { + const mockDirectCommand: SlashCommandHandler = { + name: 'docs', + description: 'Open documentation', + mode: 'direct', + execute: jest.fn(), + search: jest.fn().mockResolvedValue([ + { + id: 'docs', + title: 'Documentation', + description: 'Open documentation', + type: 'command' as const, + data: { command: 'navigation.docs', args: {} }, + }, + ]), + register: jest.fn(), + unregister: jest.fn(), + } + + const mockSubmenuCommand: SlashCommandHandler = { + name: 'theme', + description: 'Change theme', + mode: 'submenu', + search: jest.fn().mockResolvedValue([ + { + id: 'theme-light', + title: 'Light Theme', + description: 'Switch to light theme', + type: 'command' as const, + data: { command: 'theme.set', args: { theme: 'light' } }, + }, + { + id: 'theme-dark', + title: 'Dark Theme', + description: 'Switch to dark theme', + type: 'command' as const, + data: { command: 'theme.set', args: { theme: 'dark' } }, + }, + ]), + register: jest.fn(), + unregister: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + ;(slashCommandRegistry as any).findCommand = jest.fn((name: string) => { + if (name === 'docs') return mockDirectCommand + if (name === 'theme') return mockSubmenuCommand + return null + }) + ;(slashCommandRegistry as any).getAllCommands = jest.fn(() => [ + mockDirectCommand, + mockSubmenuCommand, + ]) + }) + + describe('Direct Mode Commands', () => { + it('should execute immediately when selected', () => { + const mockSetShow = jest.fn() + const mockSetSearchQuery = jest.fn() + + // Simulate command selection + const handler = slashCommandRegistry.findCommand('docs') + expect(handler?.mode).toBe('direct') + + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + mockSetShow(false) + mockSetSearchQuery('') + } + + expect(mockDirectCommand.execute).toHaveBeenCalled() + expect(mockSetShow).toHaveBeenCalledWith(false) + expect(mockSetSearchQuery).toHaveBeenCalledWith('') + }) + + it('should not enter submenu 
for direct mode commands', () => { + const handler = slashCommandRegistry.findCommand('docs') + expect(handler?.mode).toBe('direct') + expect(handler?.execute).toBeDefined() + }) + + it('should close modal after execution', () => { + const mockModalClose = jest.fn() + + const handler = slashCommandRegistry.findCommand('docs') + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + mockModalClose() + } + + expect(mockModalClose).toHaveBeenCalled() + }) + }) + + describe('Submenu Mode Commands', () => { + it('should show options instead of executing immediately', async () => { + const handler = slashCommandRegistry.findCommand('theme') + expect(handler?.mode).toBe('submenu') + + const results = await handler?.search('', 'en') + expect(results).toHaveLength(2) + expect(results?.[0].title).toBe('Light Theme') + expect(results?.[1].title).toBe('Dark Theme') + }) + + it('should not have execute function for submenu mode', () => { + const handler = slashCommandRegistry.findCommand('theme') + expect(handler?.mode).toBe('submenu') + expect(handler?.execute).toBeUndefined() + }) + + it('should keep modal open for selection', () => { + const mockModalClose = jest.fn() + + const handler = slashCommandRegistry.findCommand('theme') + // For submenu mode, modal should not close immediately + expect(handler?.mode).toBe('submenu') + expect(mockModalClose).not.toHaveBeenCalled() + }) + }) + + describe('Mode Detection and Routing', () => { + it('should correctly identify direct mode commands', () => { + const commands = slashCommandRegistry.getAllCommands() + const directCommands = commands.filter(cmd => cmd.mode === 'direct') + const submenuCommands = commands.filter(cmd => cmd.mode === 'submenu') + + expect(directCommands).toContainEqual(expect.objectContaining({ name: 'docs' })) + expect(submenuCommands).toContainEqual(expect.objectContaining({ name: 'theme' })) + }) + + it('should handle missing mode property gracefully', () => { + const commandWithoutMode: SlashCommandHandler = { + name: 'test', + description: 'Test command', + search: jest.fn(), + register: jest.fn(), + unregister: jest.fn(), + } + + ;(slashCommandRegistry as any).findCommand = jest.fn(() => commandWithoutMode) + + const handler = slashCommandRegistry.findCommand('test') + // Default behavior should be submenu when mode is not specified + expect(handler?.mode).toBeUndefined() + expect(handler?.execute).toBeUndefined() + }) + }) + + describe('Enter Key Handling', () => { + // Helper function to simulate key handler behavior + const createKeyHandler = () => { + return (commandKey: string) => { + if (commandKey.startsWith('/')) { + const commandName = commandKey.substring(1) + const handler = slashCommandRegistry.findCommand(commandName) + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + return true // Indicates handled + } + } + return false + } + } + + it('should trigger direct execution on Enter for direct mode', () => { + const keyHandler = createKeyHandler() + const handled = keyHandler('/docs') + expect(handled).toBe(true) + expect(mockDirectCommand.execute).toHaveBeenCalled() + }) + + it('should not trigger direct execution for submenu mode', () => { + const keyHandler = createKeyHandler() + const handled = keyHandler('/theme') + expect(handled).toBe(false) + expect(mockSubmenuCommand.search).not.toHaveBeenCalled() + }) + }) + + describe('Command Registration', () => { + it('should register both direct and submenu commands', () => { + mockDirectCommand.register?.({}) + 
mockSubmenuCommand.register?.({ setTheme: jest.fn() }) + + expect(mockDirectCommand.register).toHaveBeenCalled() + expect(mockSubmenuCommand.register).toHaveBeenCalled() + }) + + it('should handle unregistration for both command types', () => { + // Test unregister for direct command + mockDirectCommand.unregister?.() + expect(mockDirectCommand.unregister).toHaveBeenCalled() + + // Test unregister for submenu command + mockSubmenuCommand.unregister?.() + expect(mockSubmenuCommand.unregister).toHaveBeenCalled() + + // Verify both were called independently + expect(mockDirectCommand.unregister).toHaveBeenCalledTimes(1) + expect(mockSubmenuCommand.unregister).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 29af3e3a57..107442761a 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -60,7 +60,7 @@ export default function MailAndCodeAuth() { setEmail(e.target.value)} />
- +
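The slash-command test files above pin down one routing rule worth restating compactly: direct commands run and dismiss the palette, submenu commands keep it open and offer options. The sketch below is illustrative only; the `Handler` type and `onSelect` helper are invented here to summarize the behavior the tests assert, and are not part of the diff:

```ts
// Invented names for illustration: `Handler` condenses the SlashCommandHandler
// shape the tests exercise; `onSelect` restates the asserted routing rule.
type Handler = {
  name: string
  mode?: 'direct' | 'submenu' // undefined is treated as submenu
  execute?: () => void
  search: (query: string, locale: string) => Promise<unknown[]>
}

async function onSelect(handler: Handler, closePalette: () => void): Promise<unknown[]> {
  if (handler.mode === 'direct' && handler.execute) {
    handler.execute() // e.g. /docs: run immediately...
    closePalette() // ...and close the palette with the query cleared
    return []
  }
  // e.g. /theme: stay open and list options for a second selection
  return handler.search('', 'en')
}
```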
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 757a862c3c..d22577c9ad 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -321,7 +321,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx background={appDetail.icon_background} imageUrl={appDetail.icon_url} /> -
+
{appDetail.name}
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index 73a25d9ab9..ac07333712 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -148,7 +148,11 @@ const DatasetSidebarDropdown = ({ ) })} - +
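The config-vision hunk just below drops the `|| {}` fallback inside an object spread, and the chat hooks file later in this diff gets the same simplification. The guard was redundant because object spread treats `null` and `undefined` as contributing no properties; a self-contained check, with a type invented purely for illustration:

```ts
type ImageConfig = { enabled?: boolean; number_limits?: number } // hypothetical shape

const image: ImageConfig | undefined = undefined

// Spreading undefined (or null) in an object literal is a safe no-op,
// so `...(image || {})` and `...image` build identical objects.
const guarded = { ...(image || {}), enabled: true }
const direct = { ...image, enabled: true }

console.log(guarded) // { enabled: true }
console.log(direct) // { enabled: true }
```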
diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index f719822bf9..f0904b3fd8 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -45,7 +45,7 @@ const ConfigVision: FC = () => { if (draft.file) { draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 draft.file.image = { - ...(draft.file.image || {}), + ...draft.file.image, enabled: value, } } diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index e6b6c83846..5c87eea3ca 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -50,6 +50,7 @@ export type IGetAutomaticResProps = { onFinished: (res: GenRes) => void flowId?: string nodeId?: string + editorId?: string currentPrompt?: string isBasicMode?: boolean } @@ -76,6 +77,7 @@ const GetAutomaticRes: FC = ({ onClose, flowId, nodeId, + editorId, currentPrompt, isBasicMode, onFinished, @@ -132,7 +134,8 @@ const GetAutomaticRes: FC = ({ }, ] - const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}`}`) + // eslint-disable-next-line sonarjs/no-nested-template-literals, sonarjs/no-nested-conditional + const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}`) const instruction = instructionFromSessionStorage || '' const [ideaOutput, setIdeaOutput] = useState('') @@ -166,7 +169,7 @@ const GetAutomaticRes: FC = ({ return true } const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) - const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}`}` + const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? 
`-${editorId}` : ''}`}` const { addVersion, current, currentVersionIndex, setCurrentVersionIndex, versions } = useGenData({ storageKey, }) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 70a45a4bbe..cd73874c2c 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -1,6 +1,6 @@ 'use client' -import { useCallback, useRef, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' @@ -35,14 +35,15 @@ type CreateAppProps = { onSuccess: () => void onClose: () => void onCreateFromTemplate?: () => void + defaultAppMode?: AppMode } -function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) { +function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppProps) { const { t } = useTranslation() const { push } = useRouter() const { notify } = useContext(ToastContext) - const [appMode, setAppMode] = useState('advanced-chat') + const [appMode, setAppMode] = useState(defaultAppMode || 'advanced-chat') const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [name, setName] = useState('') @@ -55,6 +56,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) const isCreatingRef = useRef(false) + useEffect(() => { + if (appMode === 'chat' || appMode === 'agent-chat' || appMode === 'completion') + setIsAppTypeExpanded(true) + }, [appMode]) + const onCreate = useCallback(async () => { if (!appMode) { notify({ type: 'error', message: t('app.newApp.appTypeRequired') }) @@ -264,7 +270,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) type CreateAppDialogProps = CreateAppProps & { show: boolean } -const CreateAppModal = ({ show, onClose, onSuccess, onCreateFromTemplate }: CreateAppDialogProps) => { +const CreateAppModal = ({ show, onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppDialogProps) => { return ( - + ) } diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 6d08958c67..e96793ff72 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -281,12 +281,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { )} { - (isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <> - - - + (!systemFeatures.webapp_auth.enabled) + ? <> + + + + : !(isGettingUserCanAccessApp || !userCanAccessApp?.result) && ( + <> + + + + ) } { diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index aa85fb1313..4ee9a6d6d5 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -211,14 +211,14 @@ const List = () => { {(data && data[0].total > 0) ?
{isCurrentWorkspaceEditor - && } + && } {data.map(({ data: apps }) => apps.map(app => ( )))}
:
{isCurrentWorkspaceEditor - && } + && }
} diff --git a/web/app/components/apps/new-app-card.tsx b/web/app/components/apps/new-app-card.tsx index 451d2ae326..6ceeb47982 100644 --- a/web/app/components/apps/new-app-card.tsx +++ b/web/app/components/apps/new-app-card.tsx @@ -26,12 +26,14 @@ export type CreateAppCardProps = { className?: string onSuccess?: () => void ref: React.RefObject + selectedAppType?: string } const CreateAppCard = ({ ref, className, onSuccess, + selectedAppType, }: CreateAppCardProps) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() @@ -86,6 +88,7 @@ const CreateAppCard = ({ setShowNewAppTemplateDialog(true) setShowNewAppModal(false) }} + defaultAppMode={selectedAppType !== 'all' ? selectedAppType as any : undefined} /> )} {showNewAppTemplateDialog && ( diff --git a/web/app/components/base/agent-log-modal/tool-call.tsx b/web/app/components/base/agent-log-modal/tool-call.tsx index 499a70367c..433a20fd5d 100644 --- a/web/app/components/base/agent-log-modal/tool-call.tsx +++ b/web/app/components/base/agent-log-modal/tool-call.tsx @@ -33,7 +33,7 @@ const ToolCallItem: FC = ({ toolCall, isLLM = false, isFinal, tokens, obs if (time < 1) return `${(time * 1000).toFixed(3)} ms` if (time > 60) - return `${Number.parseInt(Math.round(time / 60).toString())} m ${(time % 60).toFixed(3)} s` + return `${Math.floor(time / 60)} m ${(time % 60).toFixed(3)} s` return `${time.toFixed(3)} s` } diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 17373cec9d..665e7e8bc3 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -682,7 +682,7 @@ export const useChat = ( updateChatTreeNode(targetAnswerId, { content: chatList[index].content, annotation: { - ...(chatList[index].annotation || {}), + ...chatList[index].annotation, id: '', } as Annotation, }) diff --git a/web/app/components/base/date-and-time-picker/date-picker/index.tsx b/web/app/components/base/date-and-time-picker/date-picker/index.tsx index f99b8257c1..f6b7973cb0 100644 --- a/web/app/components/base/date-and-time-picker/date-picker/index.tsx +++ b/web/app/components/base/date-and-time-picker/date-picker/index.tsx @@ -42,7 +42,14 @@ const DatePicker = ({ const [view, setView] = useState(ViewType.date) const containerRef = useRef(null) const isInitial = useRef(true) - const inputValue = useRef(value ? value.tz(timezone) : undefined).current + + // Normalize the value to ensure that all subsequent uses are Day.js objects. + const normalizedValue = useMemo(() => { + if (!value) return undefined + return dayjs.isDayjs(value) ? value.tz(timezone) : dayjs(value).tz(timezone) + }, [value, timezone]) + + const inputValue = useRef(normalizedValue).current const defaultValue = useRef(getDateWithTimezone({ timezone })).current const [currentDate, setCurrentDate] = useState(inputValue || defaultValue) @@ -68,8 +75,8 @@ const DatePicker = ({ return } clearMonthMapCache() - if (value) { - const newValue = getDateWithTimezone({ date: value, timezone }) + if (normalizedValue) { + const newValue = getDateWithTimezone({ date: normalizedValue, timezone }) setCurrentDate(newValue) setSelectedDate(newValue) onChange(newValue) @@ -88,9 +95,9 @@ const DatePicker = ({ } setView(ViewType.date) setIsOpen(true) - if (value) { - setCurrentDate(value) - setSelectedDate(value) + if (normalizedValue) { + setCurrentDate(normalizedValue) + setSelectedDate(normalizedValue) } } @@ -192,7 +199,7 @@ const DatePicker = ({ } const timeFormat = needTimePicker ? 
t('time.dateFormats.displayWithTime') : t('time.dateFormats.display') - const displayValue = value?.format(timeFormat) || '' + const displayValue = normalizedValue?.format(timeFormat) || '' const displayTime = selectedDate?.format('hh:mm A') || '--:-- --' const placeholderDate = isOpen && selectedDate ? selectedDate.format(timeFormat) : (placeholder || t('time.defaultPlaceholder')) @@ -204,7 +211,7 @@ const DatePicker = ({ > {renderTrigger ? (renderTrigger({ - value, + value: normalizedValue, selectedDate, isOpen, handleClear, diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index 53db991e71..ec8681f37c 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -83,9 +83,7 @@ const OpeningSettingModal = ({ }, [handleSave, hideConfirmAddVar]) const autoAddVar = useCallback(() => { - onAutoAddPromptVariable?.([ - ...notIncludeKeys.map(key => getNewVar(key, 'string')), - ]) + onAutoAddPromptVariable?.(notIncludeKeys.map(key => getNewVar(key, 'string'))) hideConfirmAddVar() handleSave(true) }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable]) diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index 7a95561754..81d84a85d3 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -24,7 +24,7 @@ const GA: FC = ({ if (IS_CE_EDITION) return null - const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') : '' + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' return ( <> @@ -32,7 +32,7 @@ const GA: FC = ({ strategy="beforeInteractive" async src={`https://www.googletagmanager.com/gtag/js?id=${gaIdMaps[gaType]}`} - nonce={nonce!} + nonce={nonce ?? 
undefined} > {/* Cookie banner */} diff --git a/web/app/components/base/icons/assets/vender/knowledge/api-aggregate.svg b/web/app/components/base/icons/assets/vender/knowledge/api-aggregate.svg new file mode 100644 index 0000000000..e09f4bcece --- /dev/null +++ b/web/app/components/base/icons/assets/vender/knowledge/api-aggregate.svg @@ -0,0 +1,3 @@ +<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> +<path d="M5.92578 11.0094C5.92578 10.0174 5.12163 9.21256 4.12956 9.21256C3.13752 9.2126 2.33333 10.0174 2.33333 11.0094C2.33349 12.0014 3.13762 12.8056 4.12956 12.8057C5.12153 12.8057 5.92562 12.0014 5.92578 11.0094ZM13.6667 11.0094C13.6667 10.0174 12.8625 9.2126 11.8704 9.21256C10.8784 9.21256 10.0742 10.0174 10.0742 11.0094C10.0744 12.0014 10.8785 12.8057 11.8704 12.8057C12.8624 12.8056 13.6665 12.0014 13.6667 11.0094ZM9.79622 4.32389C9.79619 3.33186 8.99205 2.52767 8 2.52767C7.00796 2.52767 6.20382 3.33186 6.20378 4.32389C6.20378 5.31596 7.00793 6.12012 8 6.12012C8.99207 6.12012 9.79622 5.31596 9.79622 4.32389ZM11.1296 4.32389C11.1296 5.82351 10.0748 7.07628 8.66667 7.38184V7.9196L9.74284 8.71387C10.3012 8.19607 11.0489 7.87923 11.8704 7.87923C13.5989 7.87927 15 9.28101 15 11.0094C14.9998 12.7377 13.5988 14.139 11.8704 14.139C10.1421 14.139 8.74104 12.7378 8.74089 11.0094C8.74089 10.5837 8.82585 10.1776 8.97982 9.80762L8 9.08366L7.01953 9.80762C7.17356 10.1777 7.25911 10.5836 7.25911 11.0094C7.25896 12.7378 5.85791 14.139 4.12956 14.139C2.40124 14.139 1.00016 12.7377 1 11.0094C1 9.28101 2.40114 7.87927 4.12956 7.87923C4.95094 7.87923 5.69819 8.19627 6.25651 8.71387L7.33333 7.9196V7.38184C5.92523 7.07628 4.87044 5.82351 4.87044 4.32389C4.87048 2.59548 6.27158 1.19434 8 1.19434C9.72843 1.19434 11.1295 2.59548 11.1296 4.32389Z" fill="currentColor"/> +</svg> diff --git a/web/app/components/base/icons/src/vender/knowledge/ApiAggregate.json b/web/app/components/base/icons/src/vender/knowledge/ApiAggregate.json new file mode 100644 index 0000000000..1057842352 --- /dev/null +++ b/web/app/components/base/icons/src/vender/knowledge/ApiAggregate.json @@ -0,0 +1,26 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M5.92578 11.0094C5.92578 10.0174 5.12163 9.21256 4.12956 9.21256C3.13752 9.2126 2.33333 10.0174 2.33333 11.0094C2.33349 12.0014 3.13762 12.8056 4.12956 12.8057C5.12153 12.8057 5.92562 12.0014 5.92578 11.0094ZM13.6667 11.0094C13.6667 10.0174 12.8625 9.2126 11.8704 9.21256C10.8784 9.21256 10.0742 10.0174 10.0742 11.0094C10.0744 12.0014 10.8785 12.8057 11.8704 12.8057C12.8624 12.8056 13.6665 12.0014 13.6667 11.0094ZM9.79622 4.32389C9.79619 3.33186 8.99205 2.52767 8 2.52767C7.00796 2.52767 6.20382 3.33186 6.20378 4.32389C6.20378 5.31596 7.00793 6.12012 8 6.12012C8.99207 6.12012 9.79622 5.31596 9.79622 4.32389ZM11.1296 4.32389C11.1296 5.82351 10.0748 7.07628 8.66667 7.38184V7.9196L9.74284 8.71387C10.3012 8.19607 11.0489 7.87923 11.8704 7.87923C13.5989 7.87927 15 9.28101 15 11.0094C14.9998 12.7377 13.5988 14.139 11.8704 14.139C10.1421 14.139 8.74104 12.7378 8.74089 11.0094C8.74089 10.5837 8.82585 10.1776 8.97982 9.80762L8 9.08366L7.01953 9.80762C7.17356 10.1777 7.25911 10.5836 7.25911 11.0094C7.25896 12.7378 5.85791 14.139 4.12956 14.139C2.40124 14.139 1.00016 12.7377 1 11.0094C1 9.28101 2.40114 7.87927 4.12956 7.87923C4.95094 7.87923 5.69819 8.19627 6.25651 8.71387L7.33333 7.9196V7.38184C5.92523 7.07628 4.87044 5.82351 4.87044 4.32389C4.87048 2.59548 6.27158 1.19434 8 1.19434C9.72843 1.19434 11.1295 2.59548 11.1296 4.32389Z", + "fill": "currentColor" + }, + "children": [] + } + ] + }, + "name": "ApiAggregate" +} diff --git a/web/app/components/base/icons/src/vender/knowledge/ApiAggregate.tsx b/web/app/components/base/icons/src/vender/knowledge/ApiAggregate.tsx new file mode 100644 index 0000000000..64193e900b --- /dev/null +++ b/web/app/components/base/icons/src/vender/knowledge/ApiAggregate.tsx @@ -0,0 +1,20 @@ +// GENERATED BY script +// DO NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './ApiAggregate.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconData } from '@/app/components/base/icons/IconBase' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps<SVGSVGElement> & { + ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>; + }, +) => <IconBase {...props} ref={ref} data={data as IconData} /> + +Icon.displayName = 'ApiAggregate' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/knowledge/index.ts b/web/app/components/base/icons/src/vender/knowledge/index.ts index 74e5a5fce8..7239511af3 100644 --- a/web/app/components/base/icons/src/vender/knowledge/index.ts +++ b/web/app/components/base/icons/src/vender/knowledge/index.ts @@ -1,4 +1,5 @@ export { default as AddChunks } from './AddChunks' +export { default as ApiAggregate } from './ApiAggregate'
export { default as ArrowShape } from './ArrowShape' export { default as Chunk } from './Chunk' export { default as Collapse } from './Collapse' diff --git a/web/app/components/base/markdown-blocks/link.tsx b/web/app/components/base/markdown-blocks/link.tsx index 0274ee0141..9bf13040a7 100644 --- a/web/app/components/base/markdown-blocks/link.tsx +++ b/web/app/components/base/markdown-blocks/link.tsx @@ -17,7 +17,7 @@ const Link = ({ node, children, ...props }: any) => { } else { const href = props.href || node.properties?.href - if (href && /^#[a-zA-Z0-9_\-]+$/.test(href.toString())) { + if (href && /^#[a-zA-Z0-9_-]+$/.test(href.toString())) { const handleClick = (e: React.MouseEvent) => { e.preventDefault() // scroll to target element if exists within the answer container diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index 46f992d758..a5813266f1 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -1,5 +1,6 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useChatContext } from '../chat/chat/context' const hasEndThink = (children: any): boolean => { if (typeof children === 'string') @@ -35,6 +36,7 @@ const removeEndThink = (children: any): any => { } const useThinkTimer = (children: any) => { + const { isResponding } = useChatContext() const [startTime] = useState(Date.now()) const [elapsedTime, setElapsedTime] = useState(0) const [isComplete, setIsComplete] = useState(false) @@ -54,9 +56,9 @@ const useThinkTimer = (children: any) => { }, [startTime, isComplete]) useEffect(() => { - if (hasEndThink(children)) + if (hasEndThink(children) || !isResponding) setIsComplete(true) - }, [children]) + }, [children, isResponding]) return { elapsedTime, isComplete } } diff --git a/web/app/components/base/notion-page-selector/page-selector/index.tsx b/web/app/components/base/notion-page-selector/page-selector/index.tsx index a61b45cbf6..c293555582 100644 --- a/web/app/components/base/notion-page-selector/page-selector/index.tsx +++ b/web/app/components/base/notion-page-selector/page-selector/index.tsx @@ -241,7 +241,7 @@ const PageSelector = ({ if (current.expand) { current.expand = false - newDataList = [...dataList.filter(item => !descendantsIds.includes(item.page_id))] + newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id)) } else { current.expand = true @@ -258,7 +258,7 @@ const PageSelector = ({ setDataList(newDataList) } - const copyValue = new Set([...value]) + const copyValue = new Set(value) const handleCheck = (index: number) => { const current = currentDataList[index] const pageId = current.page_id @@ -281,7 +281,7 @@ const PageSelector = ({ copyValue.add(pageId) } - onSelect(new Set([...copyValue])) + onSelect(new Set(copyValue)) } const handlePreview = (index: number) => { diff --git a/web/app/components/base/param-item/score-threshold-item.tsx b/web/app/components/base/param-item/score-threshold-item.tsx index b5557c80cf..3790a2a074 100644 --- a/web/app/components/base/param-item/score-threshold-item.tsx +++ b/web/app/components/base/param-item/score-threshold-item.tsx @@ -20,7 +20,6 @@ const VALUE_LIMIT = { max: 1, } -const key = 'score_threshold' const ScoreThresholdItem: FC = ({ className, value, @@ -39,9 +38,9 @@ const ScoreThresholdItem: FC = ({ return ( = ({ className, value, @@ -41,9 +40,9 @@ const TopKItem: FC = ({ return ( ), 
) diff --git a/web/app/components/base/zendesk/index.tsx b/web/app/components/base/zendesk/index.tsx new file mode 100644 index 0000000000..b3d67eb390 --- /dev/null +++ b/web/app/components/base/zendesk/index.tsx @@ -0,0 +1,21 @@ +import { memo } from 'react' +import { type UnsafeUnwrappedHeaders, headers } from 'next/headers' +import Script from 'next/script' +import { IS_CE_EDITION, ZENDESK_WIDGET_KEY } from '@/config' + +const Zendesk = () => { + if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY) + return null + + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' + + return ( +