diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..6756a2fce6 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: + - package-ecosystem: "npm" + directory: "/web" + schedule: + interval: "weekly" + open-pull-requests-limit: 2 + - package-ecosystem: "uv" + directory: "/api" + schedule: + interval: "weekly" + open-pull-requests-limit: 2 diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index be6ce80dfc..068ba686fa 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -22,12 +22,58 @@ jobs: # Fix lint errors uv run ruff check --fix . # Format code - uv run ruff format . + uv run ruff format .. + - name: ast-grep run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all + # Convert Optional[T] to T | None (ignoring quoted types) + cat > /tmp/optional-rule.yml << 'EOF' + id: convert-optional-to-union + language: python + rule: + kind: generic_type + all: + - has: + kind: identifier + pattern: Optional + - has: + kind: type_parameter + has: + kind: type + pattern: $T + fix: $T | None + EOF + uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all + # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax) + find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; + find . -name "*.py.bak" -type f -delete + - name: mdformat run: | uvx mdformat . + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup NodeJS + uses: actions/setup-node@v4 + with: + node-version: 22 + cache: pnpm + cache-dependency-path: ./web/package.json + + - name: Web dependencies + working-directory: ./web + run: pnpm install --frozen-lockfile + + - name: oxlint + working-directory: ./web + run: | + pnpx oxlint --fix + - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 98fa7c3b49..9cff3a3482 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -19,11 +19,23 @@ jobs: github.event.workflow_run.head_branch == 'deploy/enterprise' steps: - - name: Deploy to server - uses: appleboy/ssh-action@v0.1.8 - with: - host: ${{ secrets.ENTERPRISE_SSH_HOST }} - username: ${{ secrets.ENTERPRISE_SSH_USER }} - password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }} - script: | - ${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }} + - name: trigger deployments + env: + DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }} + DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }} + run: | + IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}" + BODY='{"project":"dify-api","tag":"deploy-enterprise"}' + + for ENDPOINT in "${ENDPOINTS[@]}"; do + ENDPOINT="$(echo "$ENDPOINT" | xargs)" + [ -z "$ENDPOINT" ] && continue + + API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}') + + curl -sSf -X POST \ + -H "Content-Type: application/json" \ + -H "X-Hub-Signature-256: $API_SIGNATURE" \ + -d "$BODY" \ + "$ENDPOINT" + done diff --git a/.gitignore b/.gitignore index bc354e639e..cbb7b4dac0 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,7 @@ web/public/fallback-*.js .roo/ api/.env.backup /clickzetta + +# Benchmark +scripts/stress-test/setup/config/ +scripts/stress-test/reports/ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md deleted file mode 120000 index 681311eb9c..0000000000 --- a/AGENTS.md +++ /dev/null @@ -1 +0,0 @@ -CLAUDE.md \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..44f7b30360 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,87 @@ +# AGENTS.md + +## Project Overview + +Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. + +The codebase consists of: + +- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture +- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 +- **Docker deployment** (`/docker`): Containerized deployment configurations + +## Development Commands + +### Backend (API) + +All Python commands must be prefixed with `uv run --project api`: + +```bash +# Start development servers +./dev/start-api # Start API server +./dev/start-worker # Start Celery worker + +# Run tests +uv run --project api pytest # Run all tests +uv run --project api pytest tests/unit_tests/ # Unit tests only +uv run --project api pytest tests/integration_tests/ # Integration tests + +# Code quality +./dev/reformat # Run all formatters and linters +uv run --project api ruff check --fix ./ # Fix linting issues +uv run --project api ruff format ./ # Format code +uv run --directory api basedpyright # Type checking +``` + +### Frontend (Web) + +```bash +cd web +pnpm lint # Run ESLint +pnpm eslint-fix # Fix ESLint issues +pnpm test # Run Jest tests +``` + +## Testing Guidelines + +### Backend Testing + +- Use `pytest` for all backend tests +- Write tests first (TDD approach) +- Test structure: Arrange-Act-Assert + +## Code Style Requirements + +### Python + +- Use type hints for all functions and class attributes +- No `Any` types unless absolutely necessary +- Implement special methods (`__repr__`, `__str__`) appropriately + +### TypeScript/JavaScript + +- Strict TypeScript configuration +- ESLint with Prettier integration +- Avoid `any` type + +## Important Notes + +- **Environment Variables**: Always use UV for Python commands: `uv run --project api ` +- **Comments**: Only write meaningful comments that explain "why", not "what" +- **File Creation**: Always prefer editing existing files over creating new ones +- **Documentation**: Don't create documentation files unless explicitly requested +- **Code Quality**: Always run `./dev/reformat` before committing backend changes + +## Common Development Tasks + +### Adding a New API Endpoint + +1. Create controller in `/api/controllers/` +1. Add service logic in `/api/services/` +1. Update routes in controller's `__init__.py` +1. Write tests in `/api/tests/` + +## Project-Specific Conventions + +- All async tasks use Celery with Redis as broker +- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations. diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index aea23db703..0000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,89 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. - -The codebase consists of: - -- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture -- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 -- **Docker deployment** (`/docker`): Containerized deployment configurations - -## Development Commands - -### Backend (API) - -All Python commands must be prefixed with `uv run --project api`: - -```bash -# Start development servers -./dev/start-api # Start API server -./dev/start-worker # Start Celery worker - -# Run tests -uv run --project api pytest # Run all tests -uv run --project api pytest tests/unit_tests/ # Unit tests only -uv run --project api pytest tests/integration_tests/ # Integration tests - -# Code quality -./dev/reformat # Run all formatters and linters -uv run --project api ruff check --fix ./ # Fix linting issues -uv run --project api ruff format ./ # Format code -uv run --directory api basedpyright # Type checking -``` - -### Frontend (Web) - -```bash -cd web -pnpm lint # Run ESLint -pnpm eslint-fix # Fix ESLint issues -pnpm test # Run Jest tests -``` - -## Testing Guidelines - -### Backend Testing - -- Use `pytest` for all backend tests -- Write tests first (TDD approach) -- Test structure: Arrange-Act-Assert - -## Code Style Requirements - -### Python - -- Use type hints for all functions and class attributes -- No `Any` types unless absolutely necessary -- Implement special methods (`__repr__`, `__str__`) appropriately - -### TypeScript/JavaScript - -- Strict TypeScript configuration -- ESLint with Prettier integration -- Avoid `any` type - -## Important Notes - -- **Environment Variables**: Always use UV for Python commands: `uv run --project api ` -- **Comments**: Only write meaningful comments that explain "why", not "what" -- **File Creation**: Always prefer editing existing files over creating new ones -- **Documentation**: Don't create documentation files unless explicitly requested -- **Code Quality**: Always run `./dev/reformat` before committing backend changes - -## Common Development Tasks - -### Adding a New API Endpoint - -1. Create controller in `/api/controllers/` -1. Add service logic in `/api/services/` -1. Update routes in controller's `__init__.py` -1. Write tests in `/api/tests/` - -## Project-Specific Conventions - -- All async tasks use Celery with Redis as broker -- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 0000000000..47dc3e3d86 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/Makefile b/Makefile index d82f6f24ad..ec7df3e03d 100644 --- a/Makefile +++ b/Makefile @@ -4,10 +4,13 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web API_IMAGE=$(DOCKER_REGISTRY)/dify-api VERSION=latest +# Default target - show help +.DEFAULT_GOAL := help + # Backend Development Environment Setup .PHONY: dev-setup prepare-docker prepare-web prepare-api -# Default dev setup target +# Dev setup target dev-setup: prepare-docker prepare-web prepare-api @echo "✅ Backend development environment setup complete!" @@ -46,6 +49,27 @@ dev-clean: @rm -rf api/storage @echo "✅ Cleanup complete" +# Backend Code Quality Commands +format: + @echo "🎨 Running ruff format..." + @uv run --project api --dev ruff format ./api + @echo "✅ Code formatting complete" + +check: + @echo "🔍 Running ruff check..." + @uv run --project api --dev ruff check ./api + @echo "✅ Code check complete" + +lint: + @echo "🔧 Running ruff format and check with fixes..." + @uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api' + @echo "✅ Linting complete" + +type-check: + @echo "📝 Running type check with basedpyright..." + @uv run --directory api --dev basedpyright + @echo "✅ Type check complete" + # Build Docker images build-web: @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." @@ -90,6 +114,12 @@ help: @echo " make prepare-api - Set up API environment" @echo " make dev-clean - Stop Docker middleware containers" @echo "" + @echo "Backend Code Quality:" + @echo " make format - Format code with ruff" + @echo " make check - Check code with ruff" + @echo " make lint - Format and fix code with ruff" + @echo " make type-check - Run type checking with basedpyright" + @echo "" @echo "Docker Build Targets:" @echo " make build-web - Build web Docker image" @echo " make build-api - Build API Docker image" @@ -98,4 +128,4 @@ help: @echo " make build-push-all - Build and push all Docker images" # Phony targets -.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help +.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check diff --git a/api/.env.example b/api/.env.example index 8d783af134..2986402e9e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -530,6 +530,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/api/.ruff.toml b/api/.ruff.toml index 9a15754d9a..67ad3b1449 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -5,7 +5,7 @@ line-length = 120 quote-style = "double" [lint] -preview = false +preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions @@ -65,6 +65,7 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg + "B901", # allow return in yield "B903", # class-as-data-structure "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict diff --git a/api/commands.py b/api/commands.py index 1858cb2734..58054a9adf 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,8 +1,9 @@ import base64 import json import logging +import operator import secrets -from typing import Any, Optional +from typing import Any import click import sqlalchemy as sa @@ -477,12 +478,12 @@ def convert_to_agent_apps(): click.echo(f"Converting app: {app.id}") try: - app.mode = AppMode.AGENT_CHAT.value + app.mode = AppMode.AGENT_CHAT db.session.commit() # update conversation mode to agent db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT.value} + {Conversation.mode: AppMode.AGENT_CHAT} ) db.session.commit() @@ -639,7 +640,7 @@ def old_metadata_migration(): @click.option("--email", prompt=True, help="Tenant account email.") @click.option("--name", prompt=True, help="Workspace name.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): +def create_tenant(email: str, language: str | None = None, name: str | None = None): """ Create tenant account """ @@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) + conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) click.echo( click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") ) @@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables( if dry_run: logger.info("DRY RUN: Would delete the following:") - for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[ + for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[ :10 ]: # Show top 10 logger.info(" App %s: %s variables", app_id, count) diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index f9c4d73463..9694f3db6b 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,28 +7,28 @@ class NotionConfig(BaseSettings): Configuration settings for Notion integration """ - NOTION_CLIENT_ID: Optional[str] = Field( + NOTION_CLIENT_ID: str | None = Field( description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_CLIENT_SECRET: Optional[str] = Field( + NOTION_CLIENT_SECRET: str | None = Field( description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) - NOTION_INTEGRATION_TYPE: Optional[str] = Field( + NOTION_INTEGRATION_TYPE: str | None = Field( description="Type of Notion integration." " Set to 'internal' for internal integrations, or None for public integrations.", default=None, ) - NOTION_INTERNAL_SECRET: Optional[str] = Field( + NOTION_INTERNAL_SECRET: str | None = Field( description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.", default=None, ) - NOTION_INTEGRATION_TOKEN: Optional[str] = Field( + NOTION_INTEGRATION_TOKEN: str | None = Field( description="Integration token for Notion API access. Used for direct API calls without OAuth flow.", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index f76a6bdb95..d72d01b49f 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeFloat from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class SentryConfig(BaseSettings): Configuration settings for Sentry error tracking and performance monitoring """ - SENTRY_DSN: Optional[str] = Field( + SENTRY_DSN: str | None = Field( description="Sentry Data Source Name (DSN)." " This is the unique identifier of your Sentry project, used to send events to the correct project.", default=None, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 899fecea7c..0b340c51e7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import ( AliasChoices, @@ -31,6 +31,12 @@ class SecurityConfig(BaseSettings): description="Duration in minutes for which a password reset token remains valid", default=5, ) + + EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a email register token remains valid", + default=5, + ) + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( description="Duration in minutes for which a change email token remains valid", default=5, @@ -51,7 +57,7 @@ class SecurityConfig(BaseSettings): default=False, ) - ADMIN_API_KEY: Optional[str] = Field( + ADMIN_API_KEY: str | None = Field( description="admin api key for authentication", default=None, ) @@ -91,17 +97,17 @@ class CodeExecutionSandboxConfig(BaseSettings): default="dify-sandbox", ) - CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_CONNECT_TIMEOUT: float | None = Field( description="Connection timeout in seconds for code execution requests", default=10.0, ) - CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_READ_TIMEOUT: float | None = Field( description="Read timeout in seconds for code execution requests", default=60.0, ) - CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( + CODE_EXECUTION_WRITE_TIMEOUT: float | None = Field( description="Write timeout in seconds for code execution request", default=10.0, ) @@ -362,17 +368,17 @@ class HttpConfig(BaseSettings): default=3, ) - SSRF_PROXY_ALL_URL: Optional[str] = Field( + SSRF_PROXY_ALL_URL: str | None = Field( description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTP_URL: Optional[str] = Field( + SSRF_PROXY_HTTP_URL: str | None = Field( description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) - SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + SSRF_PROXY_HTTPS_URL: str | None = Field( description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)", default=None, ) @@ -414,7 +420,7 @@ class InnerAPIConfig(BaseSettings): default=False, ) - INNER_API_KEY: Optional[str] = Field( + INNER_API_KEY: str | None = Field( description="API key for accessing the internal API", default=None, ) @@ -430,7 +436,7 @@ class LoggingConfig(BaseSettings): default="INFO", ) - LOG_FILE: Optional[str] = Field( + LOG_FILE: str | None = Field( description="File path for log output.", default=None, ) @@ -450,12 +456,12 @@ class LoggingConfig(BaseSettings): default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) - LOG_DATEFORMAT: Optional[str] = Field( + LOG_DATEFORMAT: str | None = Field( description="Date format string for log timestamps", default=None, ) - LOG_TZ: Optional[str] = Field( + LOG_TZ: str | None = Field( description="Timezone for log timestamps (e.g., 'America/New_York')", default="UTC", ) @@ -589,22 +595,22 @@ class AuthConfig(BaseSettings): default="/console/api/oauth/authorize", ) - GITHUB_CLIENT_ID: Optional[str] = Field( + GITHUB_CLIENT_ID: str | None = Field( description="GitHub OAuth client ID", default=None, ) - GITHUB_CLIENT_SECRET: Optional[str] = Field( + GITHUB_CLIENT_SECRET: str | None = Field( description="GitHub OAuth client secret", default=None, ) - GOOGLE_CLIENT_ID: Optional[str] = Field( + GOOGLE_CLIENT_ID: str | None = Field( description="Google OAuth client ID", default=None, ) - GOOGLE_CLIENT_SECRET: Optional[str] = Field( + GOOGLE_CLIENT_SECRET: str | None = Field( description="Google OAuth client secret", default=None, ) @@ -639,6 +645,11 @@ class AuthConfig(BaseSettings): default=86400, ) + EMAIL_REGISTER_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying email register after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ @@ -667,42 +678,42 @@ class MailConfig(BaseSettings): Configuration for email services """ - MAIL_TYPE: Optional[str] = Field( + MAIL_TYPE: str | None = Field( description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.", default=None, ) - MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( + MAIL_DEFAULT_SEND_FROM: str | None = Field( description="Default email address to use as the sender", default=None, ) - RESEND_API_KEY: Optional[str] = Field( + RESEND_API_KEY: str | None = Field( description="API key for Resend email service", default=None, ) - RESEND_API_URL: Optional[str] = Field( + RESEND_API_URL: str | None = Field( description="API URL for Resend email service", default=None, ) - SMTP_SERVER: Optional[str] = Field( + SMTP_SERVER: str | None = Field( description="SMTP server hostname", default=None, ) - SMTP_PORT: Optional[int] = Field( + SMTP_PORT: int | None = Field( description="SMTP server port number", default=465, ) - SMTP_USERNAME: Optional[str] = Field( + SMTP_USERNAME: str | None = Field( description="Username for SMTP authentication", default=None, ) - SMTP_PASSWORD: Optional[str] = Field( + SMTP_PASSWORD: str | None = Field( description="Password for SMTP authentication", default=None, ) @@ -722,7 +733,7 @@ class MailConfig(BaseSettings): default=50, ) - SENDGRID_API_KEY: Optional[str] = Field( + SENDGRID_API_KEY: str | None = Field( description="API key for SendGrid service", default=None, ) @@ -745,17 +756,17 @@ class RagEtlConfig(BaseSettings): default="database", ) - UNSTRUCTURED_API_URL: Optional[str] = Field( + UNSTRUCTURED_API_URL: str | None = Field( description="API URL for Unstructured.io service", default=None, ) - UNSTRUCTURED_API_KEY: Optional[str] = Field( + UNSTRUCTURED_API_KEY: str | None = Field( description="API key for Unstructured.io service", default="", ) - SCARF_NO_ANALYTICS: Optional[str] = Field( + SCARF_NO_ANALYTICS: str | None = Field( description="This is about whether to disable Scarf analytics in Unstructured library.", default="false", ) diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 18ef1ed45b..476b397ba1 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt from pydantic_settings import BaseSettings @@ -40,17 +38,17 @@ class HostedOpenAiConfig(BaseSettings): Configuration for hosted OpenAI service """ - HOSTED_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_OPENAI_API_KEY: str | None = Field( description="API key for hosted OpenAI service", default=None, ) - HOSTED_OPENAI_API_BASE: Optional[str] = Field( + HOSTED_OPENAI_API_BASE: str | None = Field( description="Base URL for hosted OpenAI API", default=None, ) - HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( + HOSTED_OPENAI_API_ORGANIZATION: str | None = Field( description="Organization ID for hosted OpenAI service", default=None, ) @@ -110,12 +108,12 @@ class HostedAzureOpenAiConfig(BaseSettings): default=False, ) - HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_KEY: str | None = Field( description="API key for hosted Azure OpenAI service", default=None, ) - HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_BASE: str | None = Field( description="Base URL for hosted Azure OpenAI API", default=None, ) @@ -131,12 +129,12 @@ class HostedAnthropicConfig(BaseSettings): Configuration for hosted Anthropic service """ - HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( + HOSTED_ANTHROPIC_API_BASE: str | None = Field( description="Base URL for hosted Anthropic API", default=None, ) - HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( + HOSTED_ANTHROPIC_API_KEY: str | None = Field( description="API key for hosted Anthropic service", default=None, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 591c24cbe0..dbad90270e 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Any, Literal, Optional +from typing import Any, Literal from urllib.parse import parse_qsl, quote_plus from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field @@ -78,18 +78,18 @@ class StorageConfig(BaseSettings): class VectorStoreConfig(BaseSettings): - VECTOR_STORE: Optional[str] = Field( + VECTOR_STORE: str | None = Field( description="Type of vector store to use for efficient similarity search." " Set to None if not using a vector store.", default=None, ) - VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( + VECTOR_STORE_WHITELIST_ENABLE: bool | None = Field( description="Enable whitelist for vector store.", default=False, ) - VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field( + VECTOR_INDEX_NAME_PREFIX: str | None = Field( description="Prefix used to create collection name in vector database", default="Vector_index", ) @@ -225,26 +225,26 @@ class CeleryConfig(DatabaseConfig): default="redis", ) - CELERY_BROKER_URL: Optional[str] = Field( + CELERY_BROKER_URL: str | None = Field( description="URL of the message broker for Celery tasks.", default=None, ) - CELERY_USE_SENTINEL: Optional[bool] = Field( + CELERY_USE_SENTINEL: bool | None = Field( description="Whether to use Redis Sentinel for high availability.", default=False, ) - CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field( + CELERY_SENTINEL_MASTER_NAME: str | None = Field( description="Name of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_PASSWORD: Optional[str] = Field( + CELERY_SENTINEL_PASSWORD: str | None = Field( description="Password of the Redis Sentinel master.", default=None, ) - CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field( description="Timeout for Redis Sentinel socket operations in seconds.", default=0.1, ) @@ -268,12 +268,12 @@ class InternalTestConfig(BaseSettings): Configuration settings for Internal Test """ - AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + AWS_SECRET_ACCESS_KEY: str | None = Field( description="Internal test AWS secret access key", default=None, ) - AWS_ACCESS_KEY_ID: Optional[str] = Field( + AWS_ACCESS_KEY_ID: str | None = Field( description="Internal test AWS access key ID", default=None, ) @@ -284,15 +284,15 @@ class DatasetQueueMonitorConfig(BaseSettings): Configuration settings for Dataset Queue Monitor """ - QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field( + QUEUE_MONITOR_THRESHOLD: NonNegativeInt | None = Field( description="Threshold for dataset queue monitor", default=200, ) - QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field( + QUEUE_MONITOR_ALERT_EMAILS: str | None = Field( description="Emails for dataset queue monitor alert, separated by commas", default=None, ) - QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field( + QUEUE_MONITOR_INTERVAL: NonNegativeFloat | None = Field( description="Interval for dataset queue monitor in minutes", default=30, ) diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 16dca98cfa..4705b28c69 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt from pydantic_settings import BaseSettings @@ -19,12 +17,12 @@ class RedisConfig(BaseSettings): default=6379, ) - REDIS_USERNAME: Optional[str] = Field( + REDIS_USERNAME: str | None = Field( description="Username for Redis authentication (if required)", default=None, ) - REDIS_PASSWORD: Optional[str] = Field( + REDIS_PASSWORD: str | None = Field( description="Password for Redis authentication (if required)", default=None, ) @@ -44,47 +42,47 @@ class RedisConfig(BaseSettings): default="CERT_NONE", ) - REDIS_SSL_CA_CERTS: Optional[str] = Field( + REDIS_SSL_CA_CERTS: str | None = Field( description="Path to the CA certificate file for SSL verification", default=None, ) - REDIS_SSL_CERTFILE: Optional[str] = Field( + REDIS_SSL_CERTFILE: str | None = Field( description="Path to the client certificate file for SSL authentication", default=None, ) - REDIS_SSL_KEYFILE: Optional[str] = Field( + REDIS_SSL_KEYFILE: str | None = Field( description="Path to the client private key file for SSL authentication", default=None, ) - REDIS_USE_SENTINEL: Optional[bool] = Field( + REDIS_USE_SENTINEL: bool | None = Field( description="Enable Redis Sentinel mode for high availability", default=False, ) - REDIS_SENTINELS: Optional[str] = Field( + REDIS_SENTINELS: str | None = Field( description="Comma-separated list of Redis Sentinel nodes (host:port)", default=None, ) - REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field( + REDIS_SENTINEL_SERVICE_NAME: str | None = Field( description="Name of the Redis Sentinel service to monitor", default=None, ) - REDIS_SENTINEL_USERNAME: Optional[str] = Field( + REDIS_SENTINEL_USERNAME: str | None = Field( description="Username for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_PASSWORD: Optional[str] = Field( + REDIS_SENTINEL_PASSWORD: str | None = Field( description="Password for Redis Sentinel authentication (if required)", default=None, ) - REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + REDIS_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field( description="Socket timeout in seconds for Redis Sentinel connections", default=0.1, ) @@ -94,12 +92,12 @@ class RedisConfig(BaseSettings): default=False, ) - REDIS_CLUSTERS: Optional[str] = Field( + REDIS_CLUSTERS: str | None = Field( description="Comma-separated list of Redis Clusters nodes (host:port)", default=None, ) - REDIS_CLUSTERS_PASSWORD: Optional[str] = Field( + REDIS_CLUSTERS_PASSWORD: str | None = Field( description="Password for Redis Clusters authentication (if required)", default=None, ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 07eb527170..331c486d54 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,37 +7,37 @@ class AliyunOSSStorageConfig(BaseSettings): Configuration settings for Aliyun Object Storage Service (OSS) """ - ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( + ALIYUN_OSS_BUCKET_NAME: str | None = Field( description="Name of the Aliyun OSS bucket to store and retrieve objects", default=None, ) - ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( + ALIYUN_OSS_ACCESS_KEY: str | None = Field( description="Access key ID for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( + ALIYUN_OSS_SECRET_KEY: str | None = Field( description="Secret access key for authenticating with Aliyun OSS", default=None, ) - ALIYUN_OSS_ENDPOINT: Optional[str] = Field( + ALIYUN_OSS_ENDPOINT: str | None = Field( description="URL of the Aliyun OSS endpoint for your chosen region", default=None, ) - ALIYUN_OSS_REGION: Optional[str] = Field( + ALIYUN_OSS_REGION: str | None = Field( description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')", default=None, ) - ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( + ALIYUN_OSS_AUTH_VERSION: str | None = Field( description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')", default=None, ) - ALIYUN_OSS_PATH: Optional[str] = Field( + ALIYUN_OSS_PATH: str | None = Field( description="Base path within the bucket to store objects (e.g., 'my-app-data/')", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index e14c210718..9277a335f7 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +9,27 @@ class S3StorageConfig(BaseSettings): Configuration settings for S3-compatible object storage """ - S3_ENDPOINT: Optional[str] = Field( + S3_ENDPOINT: str | None = Field( description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')", default=None, ) - S3_REGION: Optional[str] = Field( + S3_REGION: str | None = Field( description="Region where the S3 bucket is located (e.g., 'us-east-1')", default=None, ) - S3_BUCKET_NAME: Optional[str] = Field( + S3_BUCKET_NAME: str | None = Field( description="Name of the S3 bucket to store and retrieve objects", default=None, ) - S3_ACCESS_KEY: Optional[str] = Field( + S3_ACCESS_KEY: str | None = Field( description="Access key ID for authenticating with the S3 service", default=None, ) - S3_SECRET_KEY: Optional[str] = Field( + S3_SECRET_KEY: str | None = Field( description="Secret access key for authenticating with the S3 service", default=None, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index b7ab5247a9..7195d446b1 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class AzureBlobStorageConfig(BaseSettings): Configuration settings for Azure Blob Storage """ - AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_NAME: str | None = Field( description="Name of the Azure Storage account (e.g., 'mystorageaccount')", default=None, ) - AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_KEY: str | None = Field( description="Access key for authenticating with the Azure Storage account", default=None, ) - AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( + AZURE_BLOB_CONTAINER_NAME: str | None = Field( description="Name of the Azure Blob container to store and retrieve objects", default=None, ) - AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( + AZURE_BLOB_ACCOUNT_URL: str | None = Field( description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')", default=None, ) diff --git a/api/configs/middleware/storage/baidu_obs_storage_config.py b/api/configs/middleware/storage/baidu_obs_storage_config.py index e7913b0acc..138a0db650 100644 --- a/api/configs/middleware/storage/baidu_obs_storage_config.py +++ b/api/configs/middleware/storage/baidu_obs_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class BaiduOBSStorageConfig(BaseSettings): Configuration settings for Baidu Object Storage Service (OBS) """ - BAIDU_OBS_BUCKET_NAME: Optional[str] = Field( + BAIDU_OBS_BUCKET_NAME: str | None = Field( description="Name of the Baidu OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - BAIDU_OBS_ACCESS_KEY: Optional[str] = Field( + BAIDU_OBS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_SECRET_KEY: Optional[str] = Field( + BAIDU_OBS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Baidu OBS", default=None, ) - BAIDU_OBS_ENDPOINT: Optional[str] = Field( + BAIDU_OBS_ENDPOINT: str | None = Field( description="URL of the Baidu OSS endpoint for your chosen region (e.g., 'https://.bj.bcebos.com')", default=None, ) diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py index 56e1b6a957..035650d98a 100644 --- a/api/configs/middleware/storage/clickzetta_volume_storage_config.py +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -1,7 +1,5 @@ """ClickZetta Volume Storage Configuration""" -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ from pydantic_settings import BaseSettings class ClickZettaVolumeStorageConfig(BaseSettings): """Configuration for ClickZetta Volume storage.""" - CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( + CLICKZETTA_VOLUME_USERNAME: str | None = Field( description="Username for ClickZetta Volume authentication", default=None, ) - CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( + CLICKZETTA_VOLUME_PASSWORD: str | None = Field( description="Password for ClickZetta Volume authentication", default=None, ) - CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( + CLICKZETTA_VOLUME_INSTANCE: str | None = Field( description="ClickZetta instance identifier", default=None, ) @@ -49,7 +47,7 @@ class ClickZettaVolumeStorageConfig(BaseSettings): default="user", ) - CLICKZETTA_VOLUME_NAME: Optional[str] = Field( + CLICKZETTA_VOLUME_NAME: str | None = Field( description="ClickZetta volume name for external volumes", default=None, ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e5d763d7f5..a63eb798a8 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class GoogleCloudStorageConfig(BaseSettings): Configuration settings for Google Cloud Storage """ - GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( + GOOGLE_STORAGE_BUCKET_NAME: str | None = Field( description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')", default=None, ) - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( + GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: str | None = Field( description="Base64-encoded JSON key file for Google Cloud service account authentication", default=None, ) diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py index be983b5187..5b5cd2f750 100644 --- a/api/configs/middleware/storage/huawei_obs_storage_config.py +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class HuaweiCloudOBSStorageConfig(BaseSettings): Configuration settings for Huawei Cloud Object Storage Service (OBS) """ - HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field( + HUAWEI_OBS_BUCKET_NAME: str | None = Field( description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", default=None, ) - HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field( + HUAWEI_OBS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SECRET_KEY: Optional[str] = Field( + HUAWEI_OBS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Huawei Cloud OBS", default=None, ) - HUAWEI_OBS_SERVER: Optional[str] = Field( + HUAWEI_OBS_SERVER: str | None = Field( description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", default=None, ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index edc245bcac..70815a0055 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class OCIStorageConfig(BaseSettings): Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage """ - OCI_ENDPOINT: Optional[str] = Field( + OCI_ENDPOINT: str | None = Field( description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')", default=None, ) - OCI_REGION: Optional[str] = Field( + OCI_REGION: str | None = Field( description="OCI region where the bucket is located (e.g., 'us-phoenix-1')", default=None, ) - OCI_BUCKET_NAME: Optional[str] = Field( + OCI_BUCKET_NAME: str | None = Field( description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')", default=None, ) - OCI_ACCESS_KEY: Optional[str] = Field( + OCI_ACCESS_KEY: str | None = Field( description="Access key (also known as API key) for authenticating with OCI Object Storage", default=None, ) - OCI_SECRET_KEY: Optional[str] = Field( + OCI_SECRET_KEY: str | None = Field( description="Secret key associated with the access key for authenticating with OCI Object Storage", default=None, ) diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py index dcf7c20cf9..7f140fc5b9 100644 --- a/api/configs/middleware/storage/supabase_storage_config.py +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ class SupabaseStorageConfig(BaseSettings): Configuration settings for Supabase Object Storage Service """ - SUPABASE_BUCKET_NAME: Optional[str] = Field( + SUPABASE_BUCKET_NAME: str | None = Field( description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')", default=None, ) - SUPABASE_API_KEY: Optional[str] = Field( + SUPABASE_API_KEY: str | None = Field( description="API KEY for authenticating with Supabase", default=None, ) - SUPABASE_URL: Optional[str] = Field( + SUPABASE_URL: str | None = Field( description="URL of the Supabase", default=None, ) diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 255c4e8938..e297e748e9 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class TencentCloudCOSStorageConfig(BaseSettings): Configuration settings for Tencent Cloud Object Storage (COS) """ - TENCENT_COS_BUCKET_NAME: Optional[str] = Field( + TENCENT_COS_BUCKET_NAME: str | None = Field( description="Name of the Tencent Cloud COS bucket to store and retrieve objects", default=None, ) - TENCENT_COS_REGION: Optional[str] = Field( + TENCENT_COS_REGION: str | None = Field( description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')", default=None, ) - TENCENT_COS_SECRET_ID: Optional[str] = Field( + TENCENT_COS_SECRET_ID: str | None = Field( description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SECRET_KEY: Optional[str] = Field( + TENCENT_COS_SECRET_KEY: str | None = Field( description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) - TENCENT_COS_SCHEME: Optional[str] = Field( + TENCENT_COS_SCHEME: str | None = Field( description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, ) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index 06c3ae4d3e..be01f2dc36 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class VolcengineTOSStorageConfig(BaseSettings): Configuration settings for Volcengine Tinder Object Storage (TOS) """ - VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field( + VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')", default=None, ) - VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field( + VOLCENGINE_TOS_ACCESS_KEY: str | None = Field( description="Access Key ID for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field( + VOLCENGINE_TOS_SECRET_KEY: str | None = Field( description="Secret Access Key for authenticating with Volcengine TOS", default=None, ) - VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field( + VOLCENGINE_TOS_ENDPOINT: str | None = Field( description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')", default=None, ) - VOLCENGINE_TOS_REGION: Optional[str] = Field( + VOLCENGINE_TOS_REGION: str | None = Field( description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')", default=None, ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index cb8dc7d724..539b9c0963 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -11,37 +9,37 @@ class AnalyticdbConfig(BaseSettings): https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled """ - ANALYTICDB_KEY_ID: Optional[str] = Field( + ANALYTICDB_KEY_ID: str | None = Field( default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication." ) - ANALYTICDB_KEY_SECRET: Optional[str] = Field( + ANALYTICDB_KEY_SECRET: str | None = Field( default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access." ) - ANALYTICDB_REGION_ID: Optional[str] = Field( + ANALYTICDB_REGION_ID: str | None = Field( default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').", ) - ANALYTICDB_INSTANCE_ID: Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: str | None = Field( default=None, description="The unique identifier of the AnalyticDB instance you want to connect to.", ) - ANALYTICDB_ACCOUNT: Optional[str] = Field( + ANALYTICDB_ACCOUNT: str | None = Field( default=None, description="The account name used to log in to the AnalyticDB instance" " (usually the initial account created with the instance).", ) - ANALYTICDB_PASSWORD: Optional[str] = Field( + ANALYTICDB_PASSWORD: str | None = Field( default=None, description="The password associated with the AnalyticDB account for database authentication." ) - ANALYTICDB_NAMESPACE: Optional[str] = Field( + ANALYTICDB_NAMESPACE: str | None = Field( default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)." ) - ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( + ANALYTICDB_NAMESPACE_PASSWORD: str | None = Field( default=None, description="The password for accessing the specified namespace within the AnalyticDB instance" " (if namespace feature is enabled).", ) - ANALYTICDB_HOST: Optional[str] = Field( + ANALYTICDB_HOST: str | None = Field( default=None, description="The host of the AnalyticDB instance you want to connect to." ) ANALYTICDB_PORT: PositiveInt = Field( diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 44742c2e2f..4b6ddb3bde 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class BaiduVectorDBConfig(BaseSettings): Configuration settings for Baidu Vector Database """ - BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field( + BAIDU_VECTOR_DB_ENDPOINT: str | None = Field( description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", default=None, ) @@ -19,17 +17,17 @@ class BaiduVectorDBConfig(BaseSettings): default=30000, ) - BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( + BAIDU_VECTOR_DB_ACCOUNT: str | None = Field( description="Account for authenticating with the Baidu Vector Database", default=None, ) - BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( + BAIDU_VECTOR_DB_API_KEY: str | None = Field( description="API key for authenticating with the Baidu Vector Database service", default=None, ) - BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( + BAIDU_VECTOR_DB_DATABASE: str | None = Field( description="Name of the specific Baidu Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index e83a9902de..3a78980b91 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class ChromaConfig(BaseSettings): Configuration settings for Chroma vector database """ - CHROMA_HOST: Optional[str] = Field( + CHROMA_HOST: str | None = Field( description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')", default=None, ) @@ -19,22 +17,22 @@ class ChromaConfig(BaseSettings): default=8000, ) - CHROMA_TENANT: Optional[str] = Field( + CHROMA_TENANT: str | None = Field( description="Tenant identifier for multi-tenancy support in Chroma", default=None, ) - CHROMA_DATABASE: Optional[str] = Field( + CHROMA_DATABASE: str | None = Field( description="Name of the Chroma database to connect to", default=None, ) - CHROMA_AUTH_PROVIDER: Optional[str] = Field( + CHROMA_AUTH_PROVIDER: str | None = Field( description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)", default=None, ) - CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( + CHROMA_AUTH_CREDENTIALS: str | None = Field( description="Authentication credentials for Chroma (format depends on the auth provider)", default=None, ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py index 61bc01202b..e8172b5299 100644 --- a/api/configs/middleware/vdb/clickzetta_config.py +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,62 +7,62 @@ class ClickzettaConfig(BaseSettings): Clickzetta Lakehouse vector database configuration """ - CLICKZETTA_USERNAME: Optional[str] = Field( + CLICKZETTA_USERNAME: str | None = Field( description="Username for authenticating with Clickzetta Lakehouse", default=None, ) - CLICKZETTA_PASSWORD: Optional[str] = Field( + CLICKZETTA_PASSWORD: str | None = Field( description="Password for authenticating with Clickzetta Lakehouse", default=None, ) - CLICKZETTA_INSTANCE: Optional[str] = Field( + CLICKZETTA_INSTANCE: str | None = Field( description="Clickzetta Lakehouse instance ID", default=None, ) - CLICKZETTA_SERVICE: Optional[str] = Field( + CLICKZETTA_SERVICE: str | None = Field( description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')", default="api.clickzetta.com", ) - CLICKZETTA_WORKSPACE: Optional[str] = Field( + CLICKZETTA_WORKSPACE: str | None = Field( description="Clickzetta workspace name", default="default", ) - CLICKZETTA_VCLUSTER: Optional[str] = Field( + CLICKZETTA_VCLUSTER: str | None = Field( description="Clickzetta virtual cluster name", default="default_ap", ) - CLICKZETTA_SCHEMA: Optional[str] = Field( + CLICKZETTA_SCHEMA: str | None = Field( description="Database schema name in Clickzetta", default="public", ) - CLICKZETTA_BATCH_SIZE: Optional[int] = Field( + CLICKZETTA_BATCH_SIZE: int | None = Field( description="Batch size for bulk insert operations", default=100, ) - CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field( + CLICKZETTA_ENABLE_INVERTED_INDEX: bool | None = Field( description="Enable inverted index for full-text search capabilities", default=True, ) - CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field( + CLICKZETTA_ANALYZER_TYPE: str | None = Field( description="Analyzer type for full-text search: keyword, english, chinese, unicode", default="chinese", ) - CLICKZETTA_ANALYZER_MODE: Optional[str] = Field( + CLICKZETTA_ANALYZER_MODE: str | None = Field( description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)", default="smart", ) - CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: str | None = Field( description="Distance function for vector similarity: l2_distance or cosine_distance", default="cosine_distance", ) diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py index b81cbf8959..a365e30263 100644 --- a/api/configs/middleware/vdb/couchbase_config.py +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class CouchbaseConfig(BaseSettings): Couchbase configs """ - COUCHBASE_CONNECTION_STRING: Optional[str] = Field( + COUCHBASE_CONNECTION_STRING: str | None = Field( description="COUCHBASE connection string", default=None, ) - COUCHBASE_USER: Optional[str] = Field( + COUCHBASE_USER: str | None = Field( description="COUCHBASE user", default=None, ) - COUCHBASE_PASSWORD: Optional[str] = Field( + COUCHBASE_PASSWORD: str | None = Field( description="COUCHBASE password", default=None, ) - COUCHBASE_BUCKET_NAME: Optional[str] = Field( + COUCHBASE_BUCKET_NAME: str | None = Field( description="COUCHBASE bucket name", default=None, ) - COUCHBASE_SCOPE_NAME: Optional[str] = Field( + COUCHBASE_SCOPE_NAME: str | None = Field( description="COUCHBASE scope name", default=None, ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py index 8c4b333d45..a0efd41417 100644 --- a/api/configs/middleware/vdb/elasticsearch_config.py +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt, model_validator from pydantic_settings import BaseSettings @@ -10,7 +8,7 @@ class ElasticsearchConfig(BaseSettings): Can load from environment variables or .env files. """ - ELASTICSEARCH_HOST: Optional[str] = Field( + ELASTICSEARCH_HOST: str | None = Field( description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')", default="127.0.0.1", ) @@ -20,30 +18,28 @@ class ElasticsearchConfig(BaseSettings): default=9200, ) - ELASTICSEARCH_USERNAME: Optional[str] = Field( + ELASTICSEARCH_USERNAME: str | None = Field( description="Username for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) - ELASTICSEARCH_PASSWORD: Optional[str] = Field( + ELASTICSEARCH_PASSWORD: str | None = Field( description="Password for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) # Elastic Cloud (optional) - ELASTICSEARCH_USE_CLOUD: Optional[bool] = Field( + ELASTICSEARCH_USE_CLOUD: bool | None = Field( description="Set to True to use Elastic Cloud instead of self-hosted Elasticsearch", default=False ) - ELASTICSEARCH_CLOUD_URL: Optional[str] = Field( + ELASTICSEARCH_CLOUD_URL: str | None = Field( description="Full URL for Elastic Cloud deployment (e.g., 'https://example.es.region.aws.found.io:443')", default=None, ) - ELASTICSEARCH_API_KEY: Optional[str] = Field( - description="API key for authenticating with Elastic Cloud", default=None - ) + ELASTICSEARCH_API_KEY: str | None = Field(description="API key for authenticating with Elastic Cloud", default=None) # Common options - ELASTICSEARCH_CA_CERTS: Optional[str] = Field( + ELASTICSEARCH_CA_CERTS: str | None = Field( description="Path to CA certificate file for SSL verification", default=None ) ELASTICSEARCH_VERIFY_CERTS: bool = Field( diff --git a/api/configs/middleware/vdb/huawei_cloud_config.py b/api/configs/middleware/vdb/huawei_cloud_config.py index 2290c60499..d64cb870fa 100644 --- a/api/configs/middleware/vdb/huawei_cloud_config.py +++ b/api/configs/middleware/vdb/huawei_cloud_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,17 +7,17 @@ class HuaweiCloudConfig(BaseSettings): Configuration settings for Huawei cloud search service """ - HUAWEI_CLOUD_HOSTS: Optional[str] = Field( + HUAWEI_CLOUD_HOSTS: str | None = Field( description="Hostname or IP address of the Huawei cloud search service instance", default=None, ) - HUAWEI_CLOUD_USER: Optional[str] = Field( + HUAWEI_CLOUD_USER: str | None = Field( description="Username for authenticating with Huawei cloud search service", default=None, ) - HUAWEI_CLOUD_PASSWORD: Optional[str] = Field( + HUAWEI_CLOUD_PASSWORD: str | None = Field( description="Password for authenticating with Huawei cloud search service", default=None, ) diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py index e80e3f4a35..262d5a1f26 100644 --- a/api/configs/middleware/vdb/lindorm_config.py +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class LindormConfig(BaseSettings): Lindorm configs """ - LINDORM_URL: Optional[str] = Field( + LINDORM_URL: str | None = Field( description="Lindorm url", default=None, ) - LINDORM_USERNAME: Optional[str] = Field( + LINDORM_USERNAME: str | None = Field( description="Lindorm user", default=None, ) - LINDORM_PASSWORD: Optional[str] = Field( + LINDORM_PASSWORD: str | None = Field( description="Lindorm password", default=None, ) - DEFAULT_INDEX_TYPE: Optional[str] = Field( + DEFAULT_INDEX_TYPE: str | None = Field( description="Lindorm Vector Index Type, hnsw or flat is available in dify", default="hnsw", ) - DEFAULT_DISTANCE_TYPE: Optional[str] = Field( + DEFAULT_DISTANCE_TYPE: str | None = Field( description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2" ) - USING_UGC_INDEX: Optional[bool] = Field( + USING_UGC_INDEX: bool | None = Field( description="Using UGC index will store the same type of Index in a single index but can retrieve separately.", default=False, ) - LINDORM_QUERY_TIMEOUT: Optional[float] = Field(description="The lindorm search request timeout (s)", default=2.0) + LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index d398ef5bd8..05cee51cc9 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class MilvusConfig(BaseSettings): Configuration settings for Milvus vector database """ - MILVUS_URI: Optional[str] = Field( + MILVUS_URI: str | None = Field( description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')", default="http://127.0.0.1:19530", ) - MILVUS_TOKEN: Optional[str] = Field( + MILVUS_TOKEN: str | None = Field( description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: Optional[str] = Field( + MILVUS_USER: str | None = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, ) - MILVUS_PASSWORD: Optional[str] = Field( + MILVUS_PASSWORD: str | None = Field( description="Password for authenticating with Milvus, if username/password authentication is enabled", default=None, ) @@ -40,7 +38,7 @@ class MilvusConfig(BaseSettings): default=True, ) - MILVUS_ANALYZER_PARAMS: Optional[str] = Field( + MILVUS_ANALYZER_PARAMS: str | None = Field( description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', default=None, ) diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py index 9b11a22732..8437328e76 100644 --- a/api/configs/middleware/vdb/oceanbase_config.py +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class OceanBaseVectorConfig(BaseSettings): Configuration settings for OceanBase Vector database """ - OCEANBASE_VECTOR_HOST: Optional[str] = Field( + OCEANBASE_VECTOR_HOST: str | None = Field( description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')", default=None, ) - OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field( + OCEANBASE_VECTOR_PORT: PositiveInt | None = Field( description="Port number on which the OceanBase Vector server is listening (default is 2881)", default=2881, ) - OCEANBASE_VECTOR_USER: Optional[str] = Field( + OCEANBASE_VECTOR_USER: str | None = Field( description="Username for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field( + OCEANBASE_VECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the OceanBase Vector database", default=None, ) - OCEANBASE_VECTOR_DATABASE: Optional[str] = Field( + OCEANBASE_VECTOR_DATABASE: str | None = Field( description="Name of the OceanBase Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/opengauss_config.py b/api/configs/middleware/vdb/opengauss_config.py index 87ea292ab4..b57c1e59a9 100644 --- a/api/configs/middleware/vdb/opengauss_config.py +++ b/api/configs/middleware/vdb/opengauss_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class OpenGaussConfig(BaseSettings): Configuration settings for OpenGauss """ - OPENGAUSS_HOST: Optional[str] = Field( + OPENGAUSS_HOST: str | None = Field( description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class OpenGaussConfig(BaseSettings): default=6600, ) - OPENGAUSS_USER: Optional[str] = Field( + OPENGAUSS_USER: str | None = Field( description="Username for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_PASSWORD: Optional[str] = Field( + OPENGAUSS_PASSWORD: str | None = Field( description="Password for authenticating with the OpenGauss database", default=None, ) - OPENGAUSS_DATABASE: Optional[str] = Field( + OPENGAUSS_DATABASE: str | None = Field( description="Name of the OpenGauss database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 9fd9b60194..ba015a6eb9 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,5 +1,5 @@ -import enum -from typing import Literal, Optional +from enum import Enum +from typing import Literal from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings): Configuration settings for OpenSearch """ - class AuthMethod(enum.StrEnum): + class AuthMethod(Enum): """ Authentication method for OpenSearch """ @@ -18,7 +18,7 @@ class OpenSearchConfig(BaseSettings): BASIC = "basic" AWS_MANAGED_IAM = "aws_managed_iam" - OPENSEARCH_HOST: Optional[str] = Field( + OPENSEARCH_HOST: str | None = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, ) @@ -43,21 +43,21 @@ class OpenSearchConfig(BaseSettings): default=AuthMethod.BASIC, ) - OPENSEARCH_USER: Optional[str] = Field( + OPENSEARCH_USER: str | None = Field( description="Username for authenticating with OpenSearch", default=None, ) - OPENSEARCH_PASSWORD: Optional[str] = Field( + OPENSEARCH_PASSWORD: str | None = Field( description="Password for authenticating with OpenSearch", default=None, ) - OPENSEARCH_AWS_REGION: Optional[str] = Field( + OPENSEARCH_AWS_REGION: str | None = Field( description="AWS region for OpenSearch (e.g. 'us-west-2')", default=None, ) - OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field( + OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] | None = Field( description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index ea39909ef4..dc179e8e4f 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,33 +7,33 @@ class OracleConfig(BaseSettings): Configuration settings for Oracle database """ - ORACLE_USER: Optional[str] = Field( + ORACLE_USER: str | None = Field( description="Username for authenticating with the Oracle database", default=None, ) - ORACLE_PASSWORD: Optional[str] = Field( + ORACLE_PASSWORD: str | None = Field( description="Password for authenticating with the Oracle database", default=None, ) - ORACLE_DSN: Optional[str] = Field( + ORACLE_DSN: str | None = Field( description="Oracle database connection string. For traditional database, use format 'host:port/service_name'. " "For autonomous database, use the service name from tnsnames.ora in the wallet", default=None, ) - ORACLE_CONFIG_DIR: Optional[str] = Field( + ORACLE_CONFIG_DIR: str | None = Field( description="Directory containing the tnsnames.ora configuration file. Only used in thin mode connection", default=None, ) - ORACLE_WALLET_LOCATION: Optional[str] = Field( + ORACLE_WALLET_LOCATION: str | None = Field( description="Oracle wallet directory path containing the wallet files for secure connection", default=None, ) - ORACLE_WALLET_PASSWORD: Optional[str] = Field( + ORACLE_WALLET_PASSWORD: str | None = Field( description="Password to decrypt the Oracle wallet, if it is encrypted", default=None, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 9f5f7284d7..62334636a5 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class PGVectorConfig(BaseSettings): Configuration settings for PGVector (PostgreSQL with vector extension) """ - PGVECTOR_HOST: Optional[str] = Field( + PGVECTOR_HOST: str | None = Field( description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class PGVectorConfig(BaseSettings): default=5433, ) - PGVECTOR_USER: Optional[str] = Field( + PGVECTOR_USER: str | None = Field( description="Username for authenticating with the PostgreSQL database", default=None, ) - PGVECTOR_PASSWORD: Optional[str] = Field( + PGVECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the PostgreSQL database", default=None, ) - PGVECTOR_DATABASE: Optional[str] = Field( + PGVECTOR_DATABASE: str | None = Field( description="Name of the PostgreSQL database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index fa3bca5bb7..7bc144c4ab 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class PGVectoRSConfig(BaseSettings): Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL) """ - PGVECTO_RS_HOST: Optional[str] = Field( + PGVECTO_RS_HOST: str | None = Field( description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class PGVectoRSConfig(BaseSettings): default=5431, ) - PGVECTO_RS_USER: Optional[str] = Field( + PGVECTO_RS_USER: str | None = Field( description="Username for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_PASSWORD: Optional[str] = Field( + PGVECTO_RS_PASSWORD: str | None = Field( description="Password for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) - PGVECTO_RS_DATABASE: Optional[str] = Field( + PGVECTO_RS_DATABASE: str | None = Field( description="Name of the PostgreSQL database with PGVecto.RS extension to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index 0a753eddec..b9e8e861da 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class QdrantConfig(BaseSettings): Configuration settings for Qdrant vector database """ - QDRANT_URL: Optional[str] = Field( + QDRANT_URL: str | None = Field( description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')", default=None, ) - QDRANT_API_KEY: Optional[str] = Field( + QDRANT_API_KEY: str | None = Field( description="API key for authenticating with the Qdrant server", default=None, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index 5ffbea7b19..0ed5357852 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class RelytConfig(BaseSettings): Configuration settings for Relyt database """ - RELYT_HOST: Optional[str] = Field( + RELYT_HOST: str | None = Field( description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')", default=None, ) @@ -19,17 +17,17 @@ class RelytConfig(BaseSettings): default=9200, ) - RELYT_USER: Optional[str] = Field( + RELYT_USER: str | None = Field( description="Username for authenticating with the Relyt database", default=None, ) - RELYT_PASSWORD: Optional[str] = Field( + RELYT_PASSWORD: str | None = Field( description="Password for authenticating with the Relyt database", default=None, ) - RELYT_DATABASE: Optional[str] = Field( + RELYT_DATABASE: str | None = Field( description="Name of the Relyt database to connect to (default is 'default')", default="default", ) diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py index 1aab01c6e1..2cec384b5d 100644 --- a/api/configs/middleware/vdb/tablestore_config.py +++ b/api/configs/middleware/vdb/tablestore_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,22 +7,22 @@ class TableStoreConfig(BaseSettings): Configuration settings for TableStore. """ - TABLESTORE_ENDPOINT: Optional[str] = Field( + TABLESTORE_ENDPOINT: str | None = Field( description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')", default=None, ) - TABLESTORE_INSTANCE_NAME: Optional[str] = Field( + TABLESTORE_INSTANCE_NAME: str | None = Field( description="Instance name to access TableStore server (eg. 'instance-name')", default=None, ) - TABLESTORE_ACCESS_KEY_ID: Optional[str] = Field( + TABLESTORE_ACCESS_KEY_ID: str | None = Field( description="AccessKey id for the instance name", default=None, ) - TABLESTORE_ACCESS_KEY_SECRET: Optional[str] = Field( + TABLESTORE_ACCESS_KEY_SECRET: str | None = Field( description="AccessKey secret for the instance name", default=None, ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index a51823c3f3..3dc21ab89a 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class TencentVectorDBConfig(BaseSettings): Configuration settings for Tencent Vector Database """ - TENCENT_VECTOR_DB_URL: Optional[str] = Field( + TENCENT_VECTOR_DB_URL: str | None = Field( description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')", default=None, ) - TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( + TENCENT_VECTOR_DB_API_KEY: str | None = Field( description="API key for authenticating with the Tencent Vector Database service", default=None, ) @@ -24,12 +22,12 @@ class TencentVectorDBConfig(BaseSettings): default=30, ) - TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( + TENCENT_VECTOR_DB_USERNAME: str | None = Field( description="Username for authenticating with the Tencent Vector Database (if required)", default=None, ) - TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( + TENCENT_VECTOR_DB_PASSWORD: str | None = Field( description="Password for authenticating with the Tencent Vector Database (if required)", default=None, ) @@ -44,7 +42,7 @@ class TencentVectorDBConfig(BaseSettings): default=2, ) - TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( + TENCENT_VECTOR_DB_DATABASE: str | None = Field( description="Name of the specific Tencent Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py index d2625af264..9ca0955129 100644 --- a/api/configs/middleware/vdb/tidb_on_qdrant_config.py +++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, NonNegativeInt, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class TidbOnQdrantConfig(BaseSettings): Tidb on Qdrant configs """ - TIDB_ON_QDRANT_URL: Optional[str] = Field( + TIDB_ON_QDRANT_URL: str | None = Field( description="Tidb on Qdrant url", default=None, ) - TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( + TIDB_ON_QDRANT_API_KEY: str | None = Field( description="Tidb on Qdrant api key", default=None, ) @@ -34,37 +32,37 @@ class TidbOnQdrantConfig(BaseSettings): default=6334, ) - TIDB_PUBLIC_KEY: Optional[str] = Field( + TIDB_PUBLIC_KEY: str | None = Field( description="Tidb account public key", default=None, ) - TIDB_PRIVATE_KEY: Optional[str] = Field( + TIDB_PRIVATE_KEY: str | None = Field( description="Tidb account private key", default=None, ) - TIDB_API_URL: Optional[str] = Field( + TIDB_API_URL: str | None = Field( description="Tidb API url", default=None, ) - TIDB_IAM_API_URL: Optional[str] = Field( + TIDB_IAM_API_URL: str | None = Field( description="Tidb IAM API url", default=None, ) - TIDB_REGION: Optional[str] = Field( + TIDB_REGION: str | None = Field( description="Tidb serverless region", default="regions/aws-us-east-1", ) - TIDB_PROJECT_ID: Optional[str] = Field( + TIDB_PROJECT_ID: str | None = Field( description="Tidb project id", default=None, ) - TIDB_SPEND_LIMIT: Optional[int] = Field( + TIDB_SPEND_LIMIT: int | None = Field( description="Tidb spend limit", default=100, ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index bc68be69d8..0ebf226bea 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,27 +7,27 @@ class TiDBVectorConfig(BaseSettings): Configuration settings for TiDB Vector database """ - TIDB_VECTOR_HOST: Optional[str] = Field( + TIDB_VECTOR_HOST: str | None = Field( description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')", default=None, ) - TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( + TIDB_VECTOR_PORT: PositiveInt | None = Field( description="Port number on which the TiDB Vector server is listening (default is 4000)", default=4000, ) - TIDB_VECTOR_USER: Optional[str] = Field( + TIDB_VECTOR_USER: str | None = Field( description="Username for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_PASSWORD: Optional[str] = Field( + TIDB_VECTOR_PASSWORD: str | None = Field( description="Password for authenticating with the TiDB Vector database", default=None, ) - TIDB_VECTOR_DATABASE: Optional[str] = Field( + TIDB_VECTOR_DATABASE: str | None = Field( description="Name of the TiDB Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/upstash_config.py b/api/configs/middleware/vdb/upstash_config.py index 412c56374a..01a0442f70 100644 --- a/api/configs/middleware/vdb/upstash_config.py +++ b/api/configs/middleware/vdb/upstash_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class UpstashConfig(BaseSettings): Configuration settings for Upstash vector database """ - UPSTASH_VECTOR_URL: Optional[str] = Field( + UPSTASH_VECTOR_URL: str | None = Field( description="URL of the upstash server (e.g., 'https://vector.upstash.io')", default=None, ) - UPSTASH_VECTOR_TOKEN: Optional[str] = Field( + UPSTASH_VECTOR_TOKEN: str | None = Field( description="Token for authenticating with the upstash server", default=None, ) diff --git a/api/configs/middleware/vdb/vastbase_vector_config.py b/api/configs/middleware/vdb/vastbase_vector_config.py index 816d6df90a..ced4cf154c 100644 --- a/api/configs/middleware/vdb/vastbase_vector_config.py +++ b/api/configs/middleware/vdb/vastbase_vector_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,7 +7,7 @@ class VastbaseVectorConfig(BaseSettings): Configuration settings for Vector (Vastbase with vector extension) """ - VASTBASE_HOST: Optional[str] = Field( + VASTBASE_HOST: str | None = Field( description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')", default=None, ) @@ -19,17 +17,17 @@ class VastbaseVectorConfig(BaseSettings): default=5432, ) - VASTBASE_USER: Optional[str] = Field( + VASTBASE_USER: str | None = Field( description="Username for authenticating with the Vastbase database", default=None, ) - VASTBASE_PASSWORD: Optional[str] = Field( + VASTBASE_PASSWORD: str | None = Field( description="Password for authenticating with the Vastbase database", default=None, ) - VASTBASE_DATABASE: Optional[str] = Field( + VASTBASE_DATABASE: str | None = Field( description="Name of the Vastbase database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py index aba49ff670..3d5306bb61 100644 --- a/api/configs/middleware/vdb/vikingdb_config.py +++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from pydantic_settings import BaseSettings @@ -11,14 +9,14 @@ class VikingDBConfig(BaseSettings): https://www.volcengine.com/docs/6291/65568 """ - VIKINGDB_ACCESS_KEY: Optional[str] = Field( + VIKINGDB_ACCESS_KEY: str | None = Field( description="The Access Key provided by Volcengine VikingDB for API authentication." "Refer to the following documentation for details on obtaining credentials:" "https://www.volcengine.com/docs/6291/65568", default=None, ) - VIKINGDB_SECRET_KEY: Optional[str] = Field( + VIKINGDB_SECRET_KEY: str | None = Field( description="The Secret Key provided by Volcengine VikingDB for API authentication.", default=None, ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index 25000e8bde..6a79412ab8 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,12 +7,12 @@ class WeaviateConfig(BaseSettings): Configuration settings for Weaviate vector database """ - WEAVIATE_ENDPOINT: Optional[str] = Field( + WEAVIATE_ENDPOINT: str | None = Field( description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')", default=None, ) - WEAVIATE_API_KEY: Optional[str] = Field( + WEAVIATE_API_KEY: str | None = Field( description="API key for authenticating with the Weaviate server", default=None, ) diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py index f02f7dc9ff..55c14ead56 100644 --- a/api/configs/remote_settings_sources/apollo/__init__.py +++ b/api/configs/remote_settings_sources/apollo/__init__.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import Field from pydantic.fields import FieldInfo @@ -15,22 +15,22 @@ class ApolloSettingsSourceInfo(BaseSettings): Packaging build information """ - APOLLO_APP_ID: Optional[str] = Field( + APOLLO_APP_ID: str | None = Field( description="apollo app_id", default=None, ) - APOLLO_CLUSTER: Optional[str] = Field( + APOLLO_CLUSTER: str | None = Field( description="apollo cluster", default=None, ) - APOLLO_CONFIG_URL: Optional[str] = Field( + APOLLO_CONFIG_URL: str | None = Field( description="apollo config url", default=None, ) - APOLLO_NAMESPACE: Optional[str] = Field( + APOLLO_NAMESPACE: str | None = Field( description="apollo namespace", default=None, ) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index c26d8c0186..cacf6b6874 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { - "mode": AppMode.WORKFLOW.value, + "mode": AppMode.WORKFLOW, "enable_site": True, "enable_api": True, } @@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # completion default mode AppMode.COMPLETION: { "app": { - "mode": AppMode.COMPLETION.value, + "mode": AppMode.COMPLETION, "enable_site": True, "enable_api": True, }, @@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # chat default mode AppMode.CHAT: { "app": { - "mode": AppMode.CHAT.value, + "mode": AppMode.CHAT, "enable_site": True, "enable_api": True, }, @@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # advanced-chat default mode AppMode.ADVANCED_CHAT: { "app": { - "mode": AppMode.ADVANCED_CHAT.value, + "mode": AppMode.ADVANCED_CHAT, "enable_site": True, "enable_api": True, }, @@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # agent-chat default mode AppMode.AGENT_CHAT: { "app": { - "mode": AppMode.AGENT_CHAT.value, + "mode": AppMode.AGENT_CHAT, "enable_site": True, "enable_api": True, }, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 1400ee7085..e13edf6a37 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -54,89 +54,90 @@ api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//c # Import other controllers from . import ( - admin, # pyright: ignore[reportUnusedImport] - apikey, # pyright: ignore[reportUnusedImport] - extension, # pyright: ignore[reportUnusedImport] - feature, # pyright: ignore[reportUnusedImport] - init_validate, # pyright: ignore[reportUnusedImport] - ping, # pyright: ignore[reportUnusedImport] - setup, # pyright: ignore[reportUnusedImport] - version, # pyright: ignore[reportUnusedImport] + admin, + apikey, + extension, + feature, + init_validate, + ping, + setup, + version, ) # Import app controllers from .app import ( - advanced_prompt_template, # pyright: ignore[reportUnusedImport] - agent, # pyright: ignore[reportUnusedImport] - annotation, # pyright: ignore[reportUnusedImport] - app, # pyright: ignore[reportUnusedImport] - audio, # pyright: ignore[reportUnusedImport] - completion, # pyright: ignore[reportUnusedImport] - conversation, # pyright: ignore[reportUnusedImport] - conversation_variables, # pyright: ignore[reportUnusedImport] - generator, # pyright: ignore[reportUnusedImport] - mcp_server, # pyright: ignore[reportUnusedImport] - message, # pyright: ignore[reportUnusedImport] - model_config, # pyright: ignore[reportUnusedImport] - ops_trace, # pyright: ignore[reportUnusedImport] - site, # pyright: ignore[reportUnusedImport] - statistic, # pyright: ignore[reportUnusedImport] - workflow, # pyright: ignore[reportUnusedImport] - workflow_app_log, # pyright: ignore[reportUnusedImport] - workflow_draft_variable, # pyright: ignore[reportUnusedImport] - workflow_run, # pyright: ignore[reportUnusedImport] - workflow_statistic, # pyright: ignore[reportUnusedImport] + advanced_prompt_template, + agent, + annotation, + app, + audio, + completion, + conversation, + conversation_variables, + generator, + mcp_server, + message, + model_config, + ops_trace, + site, + statistic, + workflow, + workflow_app_log, + workflow_draft_variable, + workflow_run, + workflow_statistic, ) # Import auth controllers from .auth import ( - activate, # pyright: ignore[reportUnusedImport] - data_source_bearer_auth, # pyright: ignore[reportUnusedImport] - data_source_oauth, # pyright: ignore[reportUnusedImport] - forgot_password, # pyright: ignore[reportUnusedImport] - login, # pyright: ignore[reportUnusedImport] - oauth, # pyright: ignore[reportUnusedImport] - oauth_server, # pyright: ignore[reportUnusedImport] + activate, + data_source_bearer_auth, + data_source_oauth, + email_register, + forgot_password, + login, + oauth, + oauth_server, ) # Import billing controllers -from .billing import billing, compliance # pyright: ignore[reportUnusedImport] +from .billing import billing, compliance # Import datasets controllers from .datasets import ( - data_source, # pyright: ignore[reportUnusedImport] - datasets, # pyright: ignore[reportUnusedImport] - datasets_document, # pyright: ignore[reportUnusedImport] - datasets_segments, # pyright: ignore[reportUnusedImport] - external, # pyright: ignore[reportUnusedImport] - hit_testing, # pyright: ignore[reportUnusedImport] - metadata, # pyright: ignore[reportUnusedImport] - website, # pyright: ignore[reportUnusedImport] + data_source, + datasets, + datasets_document, + datasets_segments, + external, + hit_testing, + metadata, + website, ) # Import explore controllers from .explore import ( - installed_app, # pyright: ignore[reportUnusedImport] - parameter, # pyright: ignore[reportUnusedImport] - recommended_app, # pyright: ignore[reportUnusedImport] - saved_message, # pyright: ignore[reportUnusedImport] + installed_app, + parameter, + recommended_app, + saved_message, ) # Import tag controllers -from .tag import tags # pyright: ignore[reportUnusedImport] +from .tag import tags # Import workspace controllers from .workspace import ( - account, # pyright: ignore[reportUnusedImport] - agent_providers, # pyright: ignore[reportUnusedImport] - endpoint, # pyright: ignore[reportUnusedImport] - load_balancing_config, # pyright: ignore[reportUnusedImport] - members, # pyright: ignore[reportUnusedImport] - model_providers, # pyright: ignore[reportUnusedImport] - models, # pyright: ignore[reportUnusedImport] - plugin, # pyright: ignore[reportUnusedImport] - tool_providers, # pyright: ignore[reportUnusedImport] - workspace, # pyright: ignore[reportUnusedImport] + account, + agent_providers, + endpoint, + load_balancing_config, + members, + model_providers, + models, + plugin, + tool_providers, + workspace, ) # Explore Audio @@ -211,3 +212,70 @@ api.add_resource( ) api.add_namespace(console_ns) + +__all__ = [ + "account", + "activate", + "admin", + "advanced_prompt_template", + "agent", + "agent_providers", + "annotation", + "api", + "apikey", + "app", + "audio", + "billing", + "bp", + "completion", + "compliance", + "console_ns", + "conversation", + "conversation_variables", + "data_source", + "data_source_bearer_auth", + "data_source_oauth", + "datasets", + "datasets_document", + "datasets_segments", + "email_register", + "endpoint", + "extension", + "external", + "feature", + "forgot_password", + "generator", + "hit_testing", + "init_validate", + "installed_app", + "load_balancing_config", + "login", + "mcp_server", + "members", + "message", + "metadata", + "model_config", + "model_providers", + "models", + "oauth", + "oauth_server", + "ops_trace", + "parameter", + "ping", + "plugin", + "recommended_app", + "saved_message", + "setup", + "site", + "statistic", + "tags", + "tool_providers", + "version", + "website", + "workflow", + "workflow_app_log", + "workflow_draft_variable", + "workflow_run", + "workflow_statistic", + "workspace", +] diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 56c61c2886..fec527e4cb 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,5 +1,3 @@ -from typing import Optional - import flask_restx from flask_login import current_user from flask_restx import Resource, fields, marshal_with @@ -50,7 +48,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[type] = None + resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -103,7 +101,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[type] = None + resource_model: type | None = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c6cb6f6e3a..315825db79 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,12 +1,26 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService +@console_ns.route("/app/prompt-templates") class AdvancedPromptTemplateList(Resource): + @api.doc("get_advanced_prompt_templates") + @api.doc(description="Get advanced prompt templates based on app mode and model configuration") + @api.expect( + api.parser() + .add_argument("app_mode", type=str, required=True, location="args", help="Application mode") + .add_argument("model_mode", type=str, required=True, location="args", help="Model mode") + .add_argument("has_context", type=str, default="true", location="args", help="Whether has context") + .add_argument("model_name", type=str, required=True, location="args", help="Model name") + ) + @api.response( + 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) + ) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -19,6 +33,3 @@ class AdvancedPromptTemplateList(Resource): args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) - - -api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index a964154207..c063f336c7 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,6 +1,6 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value @@ -9,7 +9,18 @@ from models.model import AppMode from services.agent_service import AgentService +@console_ns.route("/apps//agent/logs") class AgentLogApi(Resource): + @api.doc("get_agent_logs") + @api.doc(description="Get agent execution logs for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("message_id", type=str, required=True, location="args", help="Message UUID") + .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID") + ) + @api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -23,6 +34,3 @@ class AgentLogApi(Resource): args = parser.parse_args() return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) - - -api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 37d23ccd9f..d0ee11fe75 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -2,11 +2,11 @@ from typing import Literal from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.common.errors import NoFileUploadedError, TooManyFilesError -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -21,7 +21,23 @@ from libs.login import login_required from services.annotation_service import AppAnnotationService +@console_ns.route("/apps//annotation-reply/") class AnnotationReplyActionApi(Resource): + @api.doc("annotation_reply_action") + @api.doc(description="Enable or disable annotation reply for an app") + @api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) + @api.expect( + api.model( + "AnnotationReplyActionRequest", + { + "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"), + "embedding_provider_name": fields.String(required=True, description="Embedding provider name"), + "embedding_model_name": fields.String(required=True, description="Embedding model name"), + }, + ) + ) + @api.response(200, "Action completed successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -43,7 +59,13 @@ class AnnotationReplyActionApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-setting") class AppAnnotationSettingDetailApi(Resource): + @api.doc("get_annotation_setting") + @api.doc(description="Get annotation settings for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Annotation settings retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -56,7 +78,23 @@ class AppAnnotationSettingDetailApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-settings/") class AppAnnotationSettingUpdateApi(Resource): + @api.doc("update_annotation_setting") + @api.doc(description="Update annotation settings for an app") + @api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) + @api.expect( + api.model( + "AnnotationSettingUpdateRequest", + { + "score_threshold": fields.Float(required=True, description="Score threshold"), + "embedding_provider_name": fields.String(required=True, description="Embedding provider"), + "embedding_model_name": fields.String(required=True, description="Embedding model"), + }, + ) + ) + @api.response(200, "Settings updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -75,7 +113,13 @@ class AppAnnotationSettingUpdateApi(Resource): return result, 200 +@console_ns.route("/apps//annotation-reply//status/") class AnnotationReplyActionStatusApi(Resource): + @api.doc("get_annotation_reply_action_status") + @api.doc(description="Get status of annotation reply action job") + @api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"}) + @api.response(200, "Job status retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -99,7 +143,19 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +@console_ns.route("/apps//annotations") class AnnotationApi(Resource): + @api.doc("list_annotations") + @api.doc(description="Get annotations for an app with pagination") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size") + .add_argument("keyword", type=str, location="args", default="", help="Search keyword") + ) + @api.response(200, "Annotations retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -122,6 +178,21 @@ class AnnotationApi(Resource): } return response, 200 + @api.doc("create_annotation") + @api.doc(description="Create a new annotation for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "CreateAnnotationRequest", + { + "question": fields.String(required=True, description="Question text"), + "answer": fields.String(required=True, description="Answer text"), + "annotation_reply": fields.Raw(description="Annotation reply data"), + }, + ) + ) + @api.response(201, "Annotation created successfully", annotation_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -168,7 +239,13 @@ class AnnotationApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): + @api.doc("export_annotations") + @api.doc(description="Export all annotations for an app") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields))) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -182,7 +259,14 @@ class AnnotationExportApi(Resource): return response, 200 +@console_ns.route("/apps//annotations/") class AnnotationUpdateDeleteApi(Resource): + @api.doc("update_delete_annotation") + @api.doc(description="Update or delete an annotation") + @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) + @api.response(200, "Annotation updated successfully", annotation_fields) + @api.response(204, "Annotation deleted successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -214,7 +298,14 @@ class AnnotationUpdateDeleteApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): + @api.doc("batch_import_annotations") + @api.doc(description="Batch import annotations from CSV file") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Batch import started successfully") + @api.response(403, "Insufficient permissions") + @api.response(400, "No file uploaded or too many files") @setup_required @login_required @account_initialization_required @@ -239,7 +330,13 @@ class AnnotationBatchImportApi(Resource): return AppAnnotationService.batch_import_app_annotations(app_id, file) +@console_ns.route("/apps//annotations/batch-import-status/") class AnnotationBatchImportStatusApi(Resource): + @api.doc("get_batch_import_status") + @api.doc(description="Get status of batch import job") + @api.doc(params={"app_id": "Application ID", "job_id": "Job ID"}) + @api.response(200, "Job status retrieved successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -262,7 +359,20 @@ class AnnotationBatchImportStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +@console_ns.route("/apps//annotations//hit-histories") class AnnotationHitHistoryListApi(Resource): + @api.doc("list_annotation_hit_histories") + @api.doc(description="Get hit histories for an annotation") + @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size") + ) + @api.response( + 200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields)) + ) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -285,17 +395,3 @@ class AnnotationHitHistoryListApi(Resource): "page": page, } return response - - -api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") -api.add_resource( - AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" -) -api.add_resource(AnnotationApi, "/apps//annotations") -api.add_resource(AnnotationExportApi, "/apps//annotations/export") -api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") -api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") -api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") -api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") -api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") -api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 1db9d2e764..2d2e4b448a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,12 +2,12 @@ import uuid from typing import cast from flask_login import current_user -from flask_restx import Resource, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, @@ -34,7 +34,27 @@ def _validate_description_length(description): return description +@console_ns.route("/apps") class AppListApi(Resource): + @api.doc("list_apps") + @api.doc(description="Get list of applications with pagination and filtering") + @api.expect( + api.parser() + .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1) + .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20) + .add_argument( + "mode", + type=str, + location="args", + choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"], + default="all", + help="App mode filter", + ) + .add_argument("name", type=str, location="args", help="Filter by app name") + .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs") + .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator") + ) + @api.response(200, "Success", app_pagination_fields) @setup_required @login_required @account_initialization_required @@ -91,6 +111,24 @@ class AppListApi(Resource): return marshal(app_pagination, app_pagination_fields), 200 + @api.doc("create_app") + @api.doc(description="Create a new application") + @api.expect( + api.model( + "CreateAppRequest", + { + "name": fields.String(required=True, description="App name"), + "description": fields.String(description="App description (max 400 chars)"), + "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(201, "App created successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -124,7 +162,12 @@ class AppListApi(Resource): return app, 201 +@console_ns.route("/apps/") class AppApi(Resource): + @api.doc("get_app_detail") + @api.doc(description="Get application details") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Success", app_detail_fields_with_site) @setup_required @login_required @account_initialization_required @@ -143,6 +186,26 @@ class AppApi(Resource): return app_model + @api.doc("update_app") + @api.doc(description="Update application details") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "UpdateAppRequest", + { + "name": fields.String(required=True, description="App name"), + "description": fields.String(description="App description (max 400 chars)"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), + "max_active_requests": fields.Integer(description="Maximum active requests"), + }, + ) + ) + @api.response(200, "App updated successfully", app_detail_fields_with_site) + @api.response(403, "Insufficient permissions") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -181,6 +244,11 @@ class AppApi(Resource): return app_model + @api.doc("delete_app") + @api.doc(description="Delete application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(204, "App deleted successfully") + @api.response(403, "Insufficient permissions") @get_app_model @setup_required @login_required @@ -197,7 +265,25 @@ class AppApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//copy") class AppCopyApi(Resource): + @api.doc("copy_app") + @api.doc(description="Create a copy of an existing application") + @api.doc(params={"app_id": "Application ID to copy"}) + @api.expect( + api.model( + "CopyAppRequest", + { + "name": fields.String(description="Name for the copied app"), + "description": fields.String(description="Description for the copied app"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(201, "App copied successfully", app_detail_fields_with_site) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -239,7 +325,22 @@ class AppCopyApi(Resource): return app, 201 +@console_ns.route("/apps//export") class AppExportApi(Resource): + @api.doc("export_app") + @api.doc(description="Export application configuration as DSL") + @api.doc(params={"app_id": "Application ID to export"}) + @api.expect( + api.parser() + .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export") + .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export") + ) + @api.response( + 200, + "App exported successfully", + api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), + ) + @api.response(403, "Insufficient permissions") @get_app_model @setup_required @login_required @@ -263,7 +364,13 @@ class AppExportApi(Resource): } +@console_ns.route("/apps//name") class AppNameApi(Resource): + @api.doc("check_app_name") + @api.doc(description="Check if app name is available") + @api.doc(params={"app_id": "Application ID"}) + @api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check")) + @api.response(200, "Name availability checked") @setup_required @login_required @account_initialization_required @@ -284,7 +391,23 @@ class AppNameApi(Resource): return app_model +@console_ns.route("/apps//icon") class AppIconApi(Resource): + @api.doc("update_app_icon") + @api.doc(description="Update application icon") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppIconRequest", + { + "icon": fields.String(required=True, description="Icon data"), + "icon_type": fields.String(description="Icon type"), + "icon_background": fields.String(description="Icon background color"), + }, + ) + ) + @api.response(200, "Icon updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -306,7 +429,18 @@ class AppIconApi(Resource): return app_model +@console_ns.route("/apps//site-enable") class AppSiteStatus(Resource): + @api.doc("update_app_site_status") + @api.doc(description="Enable or disable app site") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")} + ) + ) + @api.response(200, "Site status updated successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -327,7 +461,18 @@ class AppSiteStatus(Resource): return app_model +@console_ns.route("/apps//api-enable") class AppApiStatus(Resource): + @api.doc("update_app_api_status") + @api.doc(description="Enable or disable app API") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")} + ) + ) + @api.response(200, "API status updated successfully", app_detail_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -348,7 +493,12 @@ class AppApiStatus(Resource): return app_model +@console_ns.route("/apps//trace") class AppTraceApi(Resource): + @api.doc("get_app_trace") + @api.doc(description="Get app tracing configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Trace configuration retrieved successfully") @setup_required @login_required @account_initialization_required @@ -358,6 +508,20 @@ class AppTraceApi(Resource): return app_trace_config + @api.doc("update_app_trace") + @api.doc(description="Update app tracing configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppTraceRequest", + { + "enabled": fields.Boolean(required=True, description="Enable or disable tracing"), + "tracing_provider": fields.String(required=True, description="Tracing provider"), + }, + ) + ) + @api.response(200, "Trace configuration updated successfully") + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -377,14 +541,3 @@ class AppTraceApi(Resource): ) return {"result": "success"} - - -api.add_resource(AppListApi, "/apps") -api.add_resource(AppApi, "/apps/") -api.add_resource(AppCopyApi, "/apps//copy") -api.add_resource(AppExportApi, "/apps//export") -api.add_resource(AppNameApi, "/apps//name") -api.add_resource(AppIconApi, "/apps//icon") -api.add_resource(AppSiteStatus, "/apps//site-enable") -api.add_resource(AppApiStatus, "/apps//api-enable") -api.add_resource(AppTraceApi, "/apps//trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 447bcb37c2..7d659dae0d 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -34,7 +34,18 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@console_ns.route("/apps//audio-to-text") class ChatMessageAudioApi(Resource): + @api.doc("chat_message_audio_transcript") + @api.doc(description="Transcript audio to text for chat messages") + @api.doc(params={"app_id": "App ID"}) + @api.response( + 200, + "Audio transcription successful", + api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + ) + @api.response(400, "Bad request - No audio uploaded or unsupported type") + @api.response(413, "Audio file too large") @setup_required @login_required @account_initialization_required @@ -76,7 +87,24 @@ class ChatMessageAudioApi(Resource): raise InternalServerError() +@console_ns.route("/apps//text-to-audio") class ChatMessageTextApi(Resource): + @api.doc("chat_message_text_to_speech") + @api.doc(description="Convert text to speech for chat messages") + @api.doc(params={"app_id": "App ID"}) + @api.expect( + api.model( + "TextToSpeechRequest", + { + "message_id": fields.String(description="Message ID"), + "text": fields.String(required=True, description="Text to convert to speech"), + "voice": fields.String(description="Voice to use for TTS"), + "streaming": fields.Boolean(description="Whether to stream the audio"), + }, + ) + ) + @api.response(200, "Text to speech conversion successful") + @api.response(400, "Bad request - Invalid parameters") @get_app_model @setup_required @login_required @@ -124,7 +152,14 @@ class ChatMessageTextApi(Resource): raise InternalServerError() +@console_ns.route("/apps//text-to-audio/voices") class TextModesApi(Resource): + @api.doc("get_text_to_speech_voices") + @api.doc(description="Get available TTS voices for a specific language") + @api.doc(params={"app_id": "App ID"}) + @api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code")) + @api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))) + @api.response(400, "Invalid language parameter") @get_app_model @setup_required @login_required @@ -164,8 +199,3 @@ class TextModesApi(Resource): except Exception as e: logger.exception("Failed to handle get request to TextModesApi") raise InternalServerError() - - -api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") -api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") -api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 2083c15a9b..2f7b90e7fb 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restx import Resource, reqparse -from werkzeug.exceptions import InternalServerError, NotFound +from flask_restx import Resource, fields, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -38,7 +38,27 @@ logger = logging.getLogger(__name__) # define completion message api for user +@console_ns.route("/apps//completion-messages") class CompletionMessageApi(Resource): + @api.doc("create_completion_message") + @api.doc(description="Generate completion message for debugging") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "CompletionMessageRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "query": fields.String(description="Query text", default=""), + "files": fields.List(fields.Raw(), description="Uploaded files"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), + "retriever_from": fields.String(default="dev", description="Retriever source"), + }, + ) + ) + @api.response(200, "Completion generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @@ -86,7 +106,12 @@ class CompletionMessageApi(Resource): raise InternalServerError() +@console_ns.route("/apps//completion-messages//stop") class CompletionMessageStopApi(Resource): + @api.doc("stop_completion_message") + @api.doc(description="Stop a running completion message generation") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) + @api.response(200, "Task stopped successfully") @setup_required @login_required @account_initialization_required @@ -99,12 +124,40 @@ class CompletionMessageStopApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/apps//chat-messages") class ChatMessageApi(Resource): + @api.doc("create_chat_message") + @api.doc(description="Generate chat message for debugging") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "ChatMessageRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "query": fields.String(required=True, description="User query"), + "files": fields.List(fields.Raw(), description="Uploaded files"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "conversation_id": fields.String(description="Conversation ID"), + "parent_message_id": fields.String(description="Parent message ID"), + "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"), + "retriever_from": fields.String(default="dev", description="Retriever source"), + }, + ) + ) + @api.response(200, "Chat message generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(404, "App or conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): + if not isinstance(current_user, Account): + raise Forbidden() + + if not current_user.has_edit_permission: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument("query", type=str, required=True, location="json") @@ -155,7 +208,12 @@ class ChatMessageApi(Resource): raise InternalServerError() +@console_ns.route("/apps//chat-messages//stop") class ChatMessageStopApi(Resource): + @api.doc("stop_chat_message") + @api.doc(description="Stop a running chat message generation") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) + @api.response(200, "Task stopped successfully") @setup_required @login_required @account_initialization_required @@ -166,9 +224,3 @@ class ChatMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionMessageApi, "/apps//completion-messages") -api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") -api.add_resource(ChatMessageApi, "/apps//chat-messages") -api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 2f2cd66aaa..c0cbf6613e 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -8,7 +8,7 @@ from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom @@ -28,7 +28,29 @@ from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError +@console_ns.route("/apps//completion-conversations") class CompletionConversationApi(Resource): + @api.doc("list_completion_conversations") + @api.doc(description="Get completion conversations with pagination and filtering") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("keyword", type=str, location="args", help="Search keyword") + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + .add_argument( + "annotation_status", + type=str, + location="args", + choices=["annotated", "not_annotated", "all"], + default="all", + help="Annotation status filter", + ) + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") + ) + @api.response(200, "Success", conversation_pagination_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -101,7 +123,14 @@ class CompletionConversationApi(Resource): return conversations +@console_ns.route("/apps//completion-conversations/") class CompletionConversationDetailApi(Resource): + @api.doc("get_completion_conversation") + @api.doc(description="Get completion conversation details with messages") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(200, "Success", conversation_message_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -114,6 +143,12 @@ class CompletionConversationDetailApi(Resource): return _get_conversation(app_model, conversation_id) + @api.doc("delete_completion_conversation") + @api.doc(description="Delete a completion conversation") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(204, "Conversation deleted successfully") + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -133,7 +168,38 @@ class CompletionConversationDetailApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//chat-conversations") class ChatConversationApi(Resource): + @api.doc("list_chat_conversations") + @api.doc(description="Get chat conversations with pagination, filtering and summary") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("keyword", type=str, location="args", help="Search keyword") + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + .add_argument( + "annotation_status", + type=str, + location="args", + choices=["annotated", "not_annotated", "all"], + default="all", + help="Annotation status filter", + ) + .add_argument("message_count_gte", type=int, location="args", help="Minimum message count") + .add_argument("page", type=int, location="args", default=1, help="Page number") + .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") + .add_argument( + "sort_by", + type=str, + location="args", + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + default="-updated_at", + help="Sort field and direction", + ) + ) + @api.response(200, "Success", conversation_with_summary_pagination_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -241,7 +307,7 @@ class ChatConversationApi(Resource): .having(func.count(Message.id) >= args["message_count_gte"]) ) - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) match args["sort_by"]: @@ -261,7 +327,14 @@ class ChatConversationApi(Resource): return conversations +@console_ns.route("/apps//chat-conversations/") class ChatConversationDetailApi(Resource): + @api.doc("get_chat_conversation") + @api.doc(description="Get chat conversation details") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(200, "Success", conversation_detail_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @@ -274,6 +347,12 @@ class ChatConversationDetailApi(Resource): return _get_conversation(app_model, conversation_id) + @api.doc("delete_chat_conversation") + @api.doc(description="Delete a chat conversation") + @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) + @api.response(204, "Conversation deleted successfully") + @api.response(403, "Insufficient permissions") + @api.response(404, "Conversation not found") @setup_required @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @@ -293,12 +372,6 @@ class ChatConversationDetailApi(Resource): return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, "/apps//completion-conversations") -api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") -api.add_resource(ChatConversationApi, "/apps//chat-conversations") -api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") - - def _get_conversation(app_model, conversation_id): conversation = ( db.session.query(Conversation) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 5ca4c33f87..8a65a89963 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -2,7 +2,7 @@ from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -12,7 +12,17 @@ from models import ConversationVariable from models.model import AppMode +@console_ns.route("/apps//conversation-variables") class ConversationVariablesApi(Resource): + @api.doc("get_conversation_variables") + @api.doc(description="Get conversation variables for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "conversation_id", type=str, location="args", help="Conversation ID to filter variables" + ) + ) + @api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields) @setup_required @login_required @account_initialization_required @@ -55,6 +65,3 @@ class ConversationVariablesApi(Resource): for row in rows ], } - - -api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index a2cb226014..d911b25028 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,9 +1,9 @@ from collections.abc import Sequence from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -19,7 +19,23 @@ from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required +@console_ns.route("/rule-generate") class RuleGenerateApi(Resource): + @api.doc("generate_rule_config") + @api.doc(description="Generate rule configuration using LLM") + @api.expect( + api.model( + "RuleGenerateRequest", + { + "instruction": fields.String(required=True, description="Rule generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), + }, + ) + ) + @api.response(200, "Rule configuration generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -50,7 +66,26 @@ class RuleGenerateApi(Resource): return rules +@console_ns.route("/rule-code-generate") class RuleCodeGenerateApi(Resource): + @api.doc("generate_rule_code") + @api.doc(description="Generate code rules using LLM") + @api.expect( + api.model( + "RuleCodeGenerateRequest", + { + "instruction": fields.String(required=True, description="Code generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"), + "code_language": fields.String( + default="javascript", description="Programming language for code generation" + ), + }, + ) + ) + @api.response(200, "Code rules generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -82,7 +117,22 @@ class RuleCodeGenerateApi(Resource): return code_result +@console_ns.route("/rule-structured-output-generate") class RuleStructuredOutputGenerateApi(Resource): + @api.doc("generate_structured_output") + @api.doc(description="Generate structured output rules using LLM") + @api.expect( + api.model( + "StructuredOutputGenerateRequest", + { + "instruction": fields.String(required=True, description="Structured output generation instruction"), + "model_config": fields.Raw(required=True, description="Model configuration"), + }, + ) + ) + @api.response(200, "Structured output generated successfully") + @api.response(400, "Invalid request parameters") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -111,7 +161,27 @@ class RuleStructuredOutputGenerateApi(Resource): return structured_output +@console_ns.route("/instruction-generate") class InstructionGenerateApi(Resource): + @api.doc("generate_instruction") + @api.doc(description="Generate instruction for workflow nodes or general use") + @api.expect( + api.model( + "InstructionGenerateRequest", + { + "flow_id": fields.String(required=True, description="Workflow/Flow ID"), + "node_id": fields.String(description="Node ID for workflow context"), + "current": fields.String(description="Current instruction text"), + "language": fields.String(default="javascript", description="Programming language (javascript/python)"), + "instruction": fields.String(required=True, description="Instruction for generation"), + "model_config": fields.Raw(required=True, description="Model configuration"), + "ideal_output": fields.String(description="Expected ideal output"), + }, + ) + ) + @api.response(200, "Instruction generated successfully") + @api.response(400, "Invalid request parameters or flow/workflow not found") + @api.response(402, "Provider quota exceeded") @setup_required @login_required @account_initialization_required @@ -203,7 +273,21 @@ class InstructionGenerateApi(Resource): raise CompletionRequestError(e.description) +@console_ns.route("/instruction-generate/template") class InstructionGenerationTemplateApi(Resource): + @api.doc("get_instruction_template") + @api.doc(description="Get instruction generation template") + @api.expect( + api.model( + "InstructionTemplateRequest", + { + "instruction": fields.String(required=True, description="Template instruction"), + "ideal_output": fields.String(description="Expected ideal output"), + }, + ) + ) + @api.response(200, "Template retrieved successfully") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -222,10 +306,3 @@ class InstructionGenerationTemplateApi(Resource): return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} case _: raise ValueError(f"Invalid type: {args['type']}") - - -api.add_resource(RuleGenerateApi, "/rule-generate") -api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") -api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") -api.add_resource(InstructionGenerateApi, "/instruction-generate") -api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template") diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 541803e539..b9a383ee61 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,10 +2,10 @@ import json from enum import StrEnum from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -19,7 +19,12 @@ class AppMCPServerStatus(StrEnum): INACTIVE = "inactive" +@console_ns.route("/apps//server") class AppMCPServerController(Resource): + @api.doc("get_app_mcp_server") + @api.doc(description="Get MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @setup_required @login_required @account_initialization_required @@ -29,6 +34,20 @@ class AppMCPServerController(Resource): server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() return server + @api.doc("create_app_mcp_server") + @api.doc(description="Create MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MCPServerCreateRequest", + { + "description": fields.String(description="Server description"), + "parameters": fields.Raw(required=True, description="Server parameters configuration"), + }, + ) + ) + @api.response(201, "MCP server configuration created successfully", app_server_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -59,6 +78,23 @@ class AppMCPServerController(Resource): db.session.commit() return server + @api.doc("update_app_mcp_server") + @api.doc(description="Update MCP server configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MCPServerUpdateRequest", + { + "id": fields.String(required=True, description="Server ID"), + "description": fields.String(description="Server description"), + "parameters": fields.Raw(required=True, description="Server parameters configuration"), + "status": fields.String(description="Server status"), + }, + ) + ) + @api.response(200, "MCP server configuration updated successfully", app_server_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Server not found") @setup_required @login_required @account_initialization_required @@ -94,7 +130,14 @@ class AppMCPServerController(Resource): return server +@console_ns.route("/apps//server/refresh") class AppMCPServerRefreshController(Resource): + @api.doc("refresh_app_mcp_server") + @api.doc(description="Refresh MCP server configuration and regenerate server code") + @api.doc(params={"server_id": "Server ID"}) + @api.response(200, "MCP server refreshed successfully", app_server_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "Server not found") @setup_required @login_required @account_initialization_required @@ -113,7 +156,3 @@ class AppMCPServerRefreshController(Resource): server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() return server - - -api.add_resource(AppMCPServerController, "/apps//server") -api.add_resource(AppMCPServerRefreshController, "/apps//server/refresh") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 272f360c06..3bd9c53a85 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -5,7 +5,7 @@ from flask_restx.inputs import int_range from sqlalchemy import exists, select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -37,6 +37,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@console_ns.route("/apps//chat-messages") class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -44,6 +45,17 @@ class ChatMessageListApi(Resource): "data": fields.List(fields.Nested(message_detail_fields)), } + @api.doc("list_chat_messages") + @api.doc(description="Get chat messages for a conversation with pagination") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID") + .add_argument("first_id", type=str, location="args", help="First message ID for pagination") + .add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)") + ) + @api.response(200, "Success", message_infinite_scroll_pagination_fields) + @api.response(404, "Conversation not found") @setup_required @login_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @@ -117,7 +129,23 @@ class ChatMessageListApi(Resource): return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) +@console_ns.route("/apps//feedbacks") class MessageFeedbackApi(Resource): + @api.doc("create_message_feedback") + @api.doc(description="Create or update message feedback (like/dislike)") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MessageFeedbackRequest", + { + "message_id": fields.String(required=True, description="Message ID"), + "rating": fields.String(enum=["like", "dislike"], description="Feedback rating"), + }, + ) + ) + @api.response(200, "Feedback updated successfully") + @api.response(404, "Message not found") + @api.response(403, "Insufficient permissions") @get_app_model @setup_required @login_required @@ -162,7 +190,24 @@ class MessageFeedbackApi(Resource): return {"result": "success"} +@console_ns.route("/apps//annotations") class MessageAnnotationApi(Resource): + @api.doc("create_message_annotation") + @api.doc(description="Create message annotation") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "MessageAnnotationRequest", + { + "message_id": fields.String(description="Message ID"), + "question": fields.String(required=True, description="Question text"), + "answer": fields.String(required=True, description="Answer text"), + "annotation_reply": fields.Raw(description="Annotation reply"), + }, + ) + ) + @api.response(200, "Annotation created successfully", annotation_fields) + @api.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @@ -172,7 +217,7 @@ class MessageAnnotationApi(Resource): def post(self, app_model): if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -186,7 +231,16 @@ class MessageAnnotationApi(Resource): return annotation +@console_ns.route("/apps//annotations/count") class MessageAnnotationCountApi(Resource): + @api.doc("get_annotation_count") + @api.doc(description="Get count of message annotations for the app") + @api.doc(params={"app_id": "Application ID"}) + @api.response( + 200, + "Annotation count retrieved successfully", + api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + ) @get_app_model @setup_required @login_required @@ -197,7 +251,17 @@ class MessageAnnotationCountApi(Resource): return {"count": count} +@console_ns.route("/apps//chat-messages//suggested-questions") class MessageSuggestedQuestionApi(Resource): + @api.doc("get_message_suggested_questions") + @api.doc(description="Get suggested questions for a message") + @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) + @api.response( + 200, + "Suggested questions retrieved successfully", + api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}), + ) + @api.response(404, "Message or conversation not found") @setup_required @login_required @account_initialization_required @@ -230,7 +294,13 @@ class MessageSuggestedQuestionApi(Resource): return {"data": questions} +@console_ns.route("/apps//messages/") class MessageApi(Resource): + @api.doc("get_message") + @api.doc(description="Get message details by ID") + @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) + @api.response(200, "Message retrieved successfully", message_detail_fields) + @api.response(404, "Message not found") @setup_required @login_required @account_initialization_required @@ -245,11 +315,3 @@ class MessageApi(Resource): raise NotFound("Message Not Exists.") return message - - -api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") -api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") -api.add_resource(MessageFeedbackApi, "/apps//feedbacks") -api.add_resource(MessageAnnotationApi, "/apps//annotations") -api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") -api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 52ff9b923d..11df511840 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -3,9 +3,10 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields +from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.agent.entities import AgentToolEntity @@ -14,17 +15,51 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required +from models.account import Account from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService +@console_ns.route("/apps//model-config") class ModelConfigResource(Resource): + @api.doc("update_app_model_config") + @api.doc(description="Update application model configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "ModelConfigRequest", + { + "provider": fields.String(description="Model provider"), + "model": fields.String(description="Model name"), + "configs": fields.Raw(description="Model configuration parameters"), + "opening_statement": fields.String(description="Opening statement"), + "suggested_questions": fields.List(fields.String(), description="Suggested questions"), + "more_like_this": fields.Raw(description="More like this configuration"), + "speech_to_text": fields.Raw(description="Speech to text configuration"), + "text_to_speech": fields.Raw(description="Text to speech configuration"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "tools": fields.List(fields.Raw(), description="Available tools"), + "dataset_configs": fields.Raw(description="Dataset configurations"), + "agent_mode": fields.Raw(description="Agent mode configuration"), + }, + ) + ) + @api.response(200, "Model configuration updated successfully") + @api.response(400, "Invalid configuration") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" + if not isinstance(current_user, Account): + raise Forbidden() + + if not current_user.has_edit_permission: + raise Forbidden() + + assert current_user.current_tenant_id is not None, "The tenant information should be loaded." # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, @@ -39,7 +74,7 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config original_app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() @@ -142,6 +177,3 @@ class ModelConfigResource(Resource): app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) return {"result": "success"} - - -api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 74c2867c2f..981974e842 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,18 +1,31 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import BadRequest -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService +@console_ns.route("/apps//trace-config") class TraceAppConfigApi(Resource): """ Manage trace app configurations """ + @api.doc("get_trace_app_config") + @api.doc(description="Get tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" + ) + ) + @api.response( + 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") + ) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -29,6 +42,22 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("create_trace_app_config") + @api.doc(description="Create a new tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "TraceConfigCreateRequest", + { + "tracing_provider": fields.String(required=True, description="Tracing provider name"), + "tracing_config": fields.Raw(required=True, description="Tracing configuration data"), + }, + ) + ) + @api.response( + 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") + ) + @api.response(400, "Invalid request parameters or configuration already exists") @setup_required @login_required @account_initialization_required @@ -51,6 +80,20 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("update_trace_app_config") + @api.doc(description="Update an existing tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "TraceConfigUpdateRequest", + { + "tracing_provider": fields.String(required=True, description="Tracing provider name"), + "tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"), + }, + ) + ) + @api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) + @api.response(400, "Invalid request parameters or configuration not found") @setup_required @login_required @account_initialization_required @@ -71,6 +114,16 @@ class TraceAppConfigApi(Resource): except Exception as e: raise BadRequest(str(e)) + @api.doc("delete_trace_app_config") + @api.doc(description="Delete an existing tracing configuration for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser().add_argument( + "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" + ) + ) + @api.response(204, "Tracing configuration deleted successfully") + @api.response(400, "Invalid request parameters or configuration not found") @setup_required @login_required @account_initialization_required @@ -87,6 +140,3 @@ class TraceAppConfigApi(Resource): return {"result": "success"}, 204 except Exception as e: raise BadRequest(str(e)) - - -api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 871efd989c..95befc5df9 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,9 +1,9 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -36,7 +36,39 @@ def parse_app_site_args(): return parser.parse_args() +@console_ns.route("/apps//site") class AppSite(Resource): + @api.doc("update_app_site") + @api.doc(description="Update application site configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AppSiteRequest", + { + "title": fields.String(description="Site title"), + "icon_type": fields.String(description="Icon type"), + "icon": fields.String(description="Icon"), + "icon_background": fields.String(description="Icon background color"), + "description": fields.String(description="Site description"), + "default_language": fields.String(description="Default language"), + "chat_color_theme": fields.String(description="Chat color theme"), + "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"), + "customize_domain": fields.String(description="Custom domain"), + "copyright": fields.String(description="Copyright text"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "customize_token_strategy": fields.String( + enum=["must", "allow", "not_allow"], description="Token strategy" + ), + "prompt_public": fields.Boolean(description="Make prompt public"), + "show_workflow_steps": fields.Boolean(description="Show workflow steps"), + "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), + }, + ) + ) + @api.response(200, "Site configuration updated successfully", app_site_fields) + @api.response(403, "Insufficient permissions") + @api.response(404, "App not found") @setup_required @login_required @account_initialization_required @@ -84,7 +116,14 @@ class AppSite(Resource): return site +@console_ns.route("/apps//site/access-token-reset") class AppSiteAccessTokenReset(Resource): + @api.doc("reset_app_site_access_token") + @api.doc(description="Reset access token for application site") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Access token reset successfully", app_site_fields) + @api.response(403, "Insufficient permissions (admin/owner required)") + @api.response(404, "App or site not found") @setup_required @login_required @account_initialization_required @@ -108,7 +147,3 @@ class AppSiteAccessTokenReset(Resource): db.session.commit() return site - - -api.add_resource(AppSite, "/apps//site") -api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 2116732c73..6894458578 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,9 +5,9 @@ import pytz import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,7 +17,21 @@ from libs.login import login_required from models import AppMode, Message +@console_ns.route("/apps//statistics/daily-messages") class DailyMessageStatistic(Resource): + @api.doc("get_daily_message_statistics") + @api.doc(description="Get daily message statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily message statistics retrieved successfully", + fields.List(fields.Raw(description="Daily message count data")), + ) @get_app_model @setup_required @login_required @@ -74,7 +88,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/daily-conversations") class DailyConversationStatistic(Resource): + @api.doc("get_daily_conversation_statistics") + @api.doc(description="Get daily conversation statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily conversation statistics retrieved successfully", + fields.List(fields.Raw(description="Daily conversation count data")), + ) @get_app_model @setup_required @login_required @@ -126,7 +154,21 @@ class DailyConversationStatistic(Resource): return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/daily-end-users") class DailyTerminalsStatistic(Resource): + @api.doc("get_daily_terminals_statistics") + @api.doc(description="Get daily terminal/end-user statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily terminal statistics retrieved successfully", + fields.List(fields.Raw(description="Daily terminal count data")), + ) @get_app_model @setup_required @login_required @@ -183,7 +225,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/token-costs") class DailyTokenCostStatistic(Resource): + @api.doc("get_daily_token_cost_statistics") + @api.doc(description="Get daily token cost statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Daily token cost statistics retrieved successfully", + fields.List(fields.Raw(description="Daily token cost data")), + ) @get_app_model @setup_required @login_required @@ -243,7 +299,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/average-session-interactions") class AverageSessionInteractionStatistic(Resource): + @api.doc("get_average_session_interaction_statistics") + @api.doc(description="Get average session interaction statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Average session interaction statistics retrieved successfully", + fields.List(fields.Raw(description="Average session interaction data")), + ) @setup_required @login_required @account_initialization_required @@ -319,7 +389,21 @@ ORDER BY return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/user-satisfaction-rate") class UserSatisfactionRateStatistic(Resource): + @api.doc("get_user_satisfaction_rate_statistics") + @api.doc(description="Get user satisfaction rate statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "User satisfaction rate statistics retrieved successfully", + fields.List(fields.Raw(description="User satisfaction rate data")), + ) @get_app_model @setup_required @login_required @@ -385,7 +469,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/average-response-time") class AverageResponseTimeStatistic(Resource): + @api.doc("get_average_response_time_statistics") + @api.doc(description="Get average response time statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Average response time statistics retrieved successfully", + fields.List(fields.Raw(description="Average response time data")), + ) @setup_required @login_required @account_initialization_required @@ -442,7 +540,21 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//statistics/tokens-per-second") class TokensPerSecondStatistic(Resource): + @api.doc("get_tokens_per_second_statistics") + @api.doc(description="Get tokens per second statistics for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.parser() + .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") + .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") + ) + @api.response( + 200, + "Tokens per second statistics retrieved successfully", + fields.List(fields.Raw(description="Tokens per second data")), + ) @get_app_model @setup_required @login_required @@ -500,13 +612,3 @@ WHERE response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) return jsonify({"data": response_data}) - - -api.add_resource(DailyMessageStatistic, "/apps//statistics/daily-messages") -api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") -api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") -api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") -api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") -api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") -api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") -api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 05178328fe..bbbe1e9ec8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -4,13 +4,13 @@ from collections.abc import Sequence from typing import cast from flask import abort, request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -57,7 +57,13 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence return file_objs +@console_ns.route("/apps//workflows/draft") class DraftWorkflowApi(Resource): + @api.doc("get_draft_workflow") + @api.doc(description="Get draft workflow for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Draft workflow retrieved successfully", workflow_fields) + @api.response(404, "Draft workflow not found") @setup_required @login_required @account_initialization_required @@ -69,7 +75,7 @@ class DraftWorkflowApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # fetch draft workflow by app_model @@ -86,13 +92,30 @@ class DraftWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @api.doc("sync_draft_workflow") + @api.doc(description="Sync draft workflow configuration") + @api.expect( + api.model( + "SyncDraftWorkflowRequest", + { + "graph": fields.Raw(required=True, description="Workflow graph configuration"), + "features": fields.Raw(required=True, description="Workflow features configuration"), + "hash": fields.String(description="Workflow hash for validation"), + "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"), + "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), + }, + ) + ) + @api.response(200, "Draft workflow synced successfully", workflow_fields) + @api.response(400, "Invalid workflow configuration") + @api.response(403, "Permission denied") def post(self, app_model: App): """ Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() content_type = request.headers.get("Content-Type", "") @@ -159,7 +182,25 @@ class DraftWorkflowApi(Resource): } +@console_ns.route("/apps//advanced-chat/workflows/draft/run") class AdvancedChatDraftWorkflowRunApi(Resource): + @api.doc("run_advanced_chat_draft_workflow") + @api.doc(description="Run draft workflow for advanced chat application") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "AdvancedChatWorkflowRunRequest", + { + "query": fields.String(required=True, description="User query"), + "inputs": fields.Raw(description="Input variables"), + "files": fields.List(fields.Raw, description="File uploads"), + "conversation_id": fields.String(description="Conversation ID"), + }, + ) + ) + @api.response(200, "Workflow run started successfully") + @api.response(400, "Invalid request parameters") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -170,7 +211,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() if not isinstance(current_user, Account): @@ -208,7 +249,23 @@ class AdvancedChatDraftWorkflowRunApi(Resource): raise InternalServerError() +@console_ns.route("/apps//advanced-chat/workflows/draft/iteration/nodes//run") class AdvancedChatDraftRunIterationNodeApi(Resource): + @api.doc("run_advanced_chat_draft_iteration_node") + @api.doc(description="Run draft workflow iteration node for advanced chat") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "IterationNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Iteration node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -220,7 +277,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -244,7 +301,23 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//workflows/draft/iteration/nodes//run") class WorkflowDraftRunIterationNodeApi(Resource): + @api.doc("run_workflow_draft_iteration_node") + @api.doc(description="Run draft workflow iteration node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "WorkflowIterationNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Workflow iteration node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -256,7 +329,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -280,7 +353,23 @@ class WorkflowDraftRunIterationNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//advanced-chat/workflows/draft/loop/nodes//run") class AdvancedChatDraftRunLoopNodeApi(Resource): + @api.doc("run_advanced_chat_draft_loop_node") + @api.doc(description="Run draft workflow loop node for advanced chat") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "LoopNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Loop node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -293,7 +382,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -317,7 +406,23 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//workflows/draft/loop/nodes//run") class WorkflowDraftRunLoopNodeApi(Resource): + @api.doc("run_workflow_draft_loop_node") + @api.doc(description="Run draft workflow loop node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "WorkflowLoopNodeRunRequest", + { + "task_id": fields.String(required=True, description="Task ID"), + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Workflow loop node run started successfully") + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -330,7 +435,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -354,7 +459,22 @@ class WorkflowDraftRunLoopNodeApi(Resource): raise InternalServerError() +@console_ns.route("/apps//workflows/draft/run") class DraftWorkflowRunApi(Resource): + @api.doc("run_draft_workflow") + @api.doc(description="Run draft workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.expect( + api.model( + "DraftWorkflowRunRequest", + { + "inputs": fields.Raw(required=True, description="Input variables"), + "files": fields.List(fields.Raw, description="File uploads"), + }, + ) + ) + @api.response(200, "Draft workflow run started successfully") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -367,7 +487,7 @@ class DraftWorkflowRunApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -393,7 +513,14 @@ class DraftWorkflowRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) +@console_ns.route("/apps//workflows/tasks//stop") class WorkflowTaskStopApi(Resource): + @api.doc("stop_workflow_task") + @api.doc(description="Stop running workflow task") + @api.doc(params={"app_id": "Application ID", "task_id": "Task ID"}) + @api.response(200, "Task stopped successfully") + @api.response(404, "Task not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -406,7 +533,7 @@ class WorkflowTaskStopApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) @@ -414,7 +541,22 @@ class WorkflowTaskStopApi(Resource): return {"result": "success"} +@console_ns.route("/apps//workflows/draft/nodes//run") class DraftWorkflowNodeRunApi(Resource): + @api.doc("run_draft_workflow_node") + @api.doc(description="Run draft workflow node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.expect( + api.model( + "DraftWorkflowNodeRunRequest", + { + "inputs": fields.Raw(description="Input variables"), + }, + ) + ) + @api.response(200, "Node run started successfully", workflow_run_node_execution_fields) + @api.response(403, "Permission denied") + @api.response(404, "Node not found") @setup_required @login_required @account_initialization_required @@ -428,7 +570,7 @@ class DraftWorkflowNodeRunApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -462,7 +604,13 @@ class DraftWorkflowNodeRunApi(Resource): return workflow_node_execution +@console_ns.route("/apps//workflows/publish") class PublishedWorkflowApi(Resource): + @api.doc("get_published_workflow") + @api.doc(description="Get published workflow for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Published workflow retrieved successfully", workflow_fields) + @api.response(404, "Published workflow not found") @setup_required @login_required @account_initialization_required @@ -476,7 +624,7 @@ class PublishedWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # fetch published workflow by app_model @@ -497,7 +645,7 @@ class PublishedWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -534,7 +682,12 @@ class PublishedWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/default-block-configs") class DefaultBlockConfigsApi(Resource): + @api.doc("get_default_block_configs") + @api.doc(description="Get default block configurations for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Default block configurations retrieved successfully") @setup_required @login_required @account_initialization_required @@ -547,7 +700,7 @@ class DefaultBlockConfigsApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() # Get default block configs @@ -555,7 +708,13 @@ class DefaultBlockConfigsApi(Resource): return workflow_service.get_default_block_configs() +@console_ns.route("/apps//workflows/default-block-configs/") class DefaultBlockConfigApi(Resource): + @api.doc("get_default_block_config") + @api.doc(description="Get default block configuration by type") + @api.doc(params={"app_id": "Application ID", "block_type": "Block type"}) + @api.response(200, "Default block configuration retrieved successfully") + @api.response(404, "Block type not found") @setup_required @login_required @account_initialization_required @@ -567,7 +726,7 @@ class DefaultBlockConfigApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -588,7 +747,14 @@ class DefaultBlockConfigApi(Resource): return workflow_service.get_default_block_config(node_type=block_type, filters=filters) +@console_ns.route("/apps//convert-to-workflow") class ConvertToWorkflowApi(Resource): + @api.doc("convert_to_workflow") + @api.doc(description="Convert application to workflow mode") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Application converted to workflow successfully") + @api.response(400, "Application cannot be converted") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -602,7 +768,7 @@ class ConvertToWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() if request.data: @@ -625,9 +791,14 @@ class ConvertToWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/config") class WorkflowConfigApi(Resource): """Resource for workflow configuration.""" + @api.doc("get_workflow_config") + @api.doc(description="Get workflow configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Workflow configuration retrieved successfully") @setup_required @login_required @account_initialization_required @@ -638,7 +809,12 @@ class WorkflowConfigApi(Resource): } +@console_ns.route("/apps//workflows/published") class PublishedAllWorkflowApi(Resource): + @api.doc("get_all_published_workflows") + @api.doc(description="Get all published workflows for an application") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Published workflows retrieved successfully", workflow_pagination_fields) @setup_required @login_required @account_initialization_required @@ -651,7 +827,7 @@ class PublishedAllWorkflowApi(Resource): if not isinstance(current_user, Account): raise Forbidden() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -689,7 +865,23 @@ class PublishedAllWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): + @api.doc("update_workflow_by_id") + @api.doc(description="Update workflow by ID") + @api.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"}) + @api.expect( + api.model( + "UpdateWorkflowRequest", + { + "environment_variables": fields.List(fields.Raw, description="Environment variables"), + "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), + }, + ) + ) + @api.response(200, "Workflow updated successfully", workflow_fields) + @api.response(404, "Workflow not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -702,7 +894,7 @@ class WorkflowByIdApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # Check permission - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -715,7 +907,6 @@ class WorkflowByIdApi(Resource): raise ValueError("Marked name cannot exceed 20 characters") if args.marked_comment and len(args.marked_comment) > 100: raise ValueError("Marked comment cannot exceed 100 characters") - args = parser.parse_args() # Prepare update data update_data = {} @@ -758,7 +949,7 @@ class WorkflowByIdApi(Resource): if not isinstance(current_user, Account): raise Forbidden() # Check permission - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() workflow_service = WorkflowService() @@ -781,7 +972,14 @@ class WorkflowByIdApi(Resource): return None, 204 +@console_ns.route("/apps//workflows/draft/nodes//last-run") class DraftWorkflowNodeLastRunApi(Resource): + @api.doc("get_draft_workflow_node_last_run") + @api.doc(description="Get last run result for draft workflow node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields) + @api.response(404, "Node last run not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -800,73 +998,3 @@ class DraftWorkflowNodeLastRunApi(Resource): if node_exec is None: raise NotFound("last run not found") return node_exec - - -api.add_resource( - DraftWorkflowApi, - "/apps//workflows/draft", -) -api.add_resource( - WorkflowConfigApi, - "/apps//workflows/draft/config", -) -api.add_resource( - AdvancedChatDraftWorkflowRunApi, - "/apps//advanced-chat/workflows/draft/run", -) -api.add_resource( - DraftWorkflowRunApi, - "/apps//workflows/draft/run", -) -api.add_resource( - WorkflowTaskStopApi, - "/apps//workflow-runs/tasks//stop", -) -api.add_resource( - DraftWorkflowNodeRunApi, - "/apps//workflows/draft/nodes//run", -) -api.add_resource( - AdvancedChatDraftRunIterationNodeApi, - "/apps//advanced-chat/workflows/draft/iteration/nodes//run", -) -api.add_resource( - WorkflowDraftRunIterationNodeApi, - "/apps//workflows/draft/iteration/nodes//run", -) -api.add_resource( - AdvancedChatDraftRunLoopNodeApi, - "/apps//advanced-chat/workflows/draft/loop/nodes//run", -) -api.add_resource( - WorkflowDraftRunLoopNodeApi, - "/apps//workflows/draft/loop/nodes//run", -) -api.add_resource( - PublishedWorkflowApi, - "/apps//workflows/publish", -) -api.add_resource( - PublishedAllWorkflowApi, - "/apps//workflows", -) -api.add_resource( - DefaultBlockConfigsApi, - "/apps//workflows/default-workflow-block-configs", -) -api.add_resource( - DefaultBlockConfigApi, - "/apps//workflows/default-workflow-block-configs/", -) -api.add_resource( - ConvertToWorkflowApi, - "/apps//convert-to-workflow", -) -api.add_resource( - WorkflowByIdApi, - "/apps//workflows/", -) -api.add_resource( - DraftWorkflowNodeLastRunApi, - "/apps//workflows/draft/nodes//last-run", -) diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 76f02041ef..eb64faf6a5 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,7 +3,7 @@ from flask_restx import Resource, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.workflow.entities.workflow_execution import WorkflowExecutionStatus @@ -15,7 +15,24 @@ from models.model import AppMode from services.workflow_app_service import WorkflowAppService +@console_ns.route("/apps//workflow-app-logs") class WorkflowAppLogApi(Resource): + @api.doc("get_workflow_app_logs") + @api.doc(description="Get workflow application execution logs") + @api.doc(params={"app_id": "Application ID"}) + @api.doc( + params={ + "keyword": "Search keyword for filtering logs", + "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)", + "created_at__before": "Filter logs created before this timestamp", + "created_at__after": "Filter logs created after this timestamp", + "created_by_end_user_session_id": "Filter by end user session ID", + "created_by_account": "Filter by account", + "page": "Page number (1-99999)", + "limit": "Number of items per page (1-100)", + } + ) + @api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields) @setup_required @login_required @account_initialization_required @@ -78,6 +95,3 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination - - -api.add_resource(WorkflowAppLogApi, "/apps//workflow-app-logs") diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 5fced3e90f..eff25eb2e5 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -6,7 +6,7 @@ from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqpars from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, ) @@ -137,14 +137,20 @@ def _api_prerequisite(f): @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def wrapper(*args, **kwargs): assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) return wrapper +@console_ns.route("/apps//workflows/draft/variables") class WorkflowVariableCollectionApi(Resource): + @api.doc("get_workflow_variables") + @api.doc(description="Get draft workflow variables") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"}) + @api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) def get(self, app_model: App): @@ -173,6 +179,9 @@ class WorkflowVariableCollectionApi(Resource): return workflow_vars + @api.doc("delete_workflow_variables") + @api.doc(description="Delete all draft workflow variables") + @api.response(204, "Workflow variables deleted successfully") @_api_prerequisite def delete(self, app_model: App): draft_var_srv = WorkflowDraftVariableService( @@ -201,7 +210,12 @@ def validate_node_id(node_id: str) -> NoReturn | None: return None +@console_ns.route("/apps//workflows/draft/nodes//variables") class NodeVariableCollectionApi(Resource): + @api.doc("get_node_variables") + @api.doc(description="Get variables for a specific node") + @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) + @api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App, node_id: str): @@ -214,6 +228,9 @@ class NodeVariableCollectionApi(Resource): return node_vars + @api.doc("delete_node_variables") + @api.doc(description="Delete all variables for a specific node") + @api.response(204, "Node variables deleted successfully") @_api_prerequisite def delete(self, app_model: App, node_id: str): validate_node_id(node_id) @@ -223,10 +240,16 @@ class NodeVariableCollectionApi(Resource): return Response("", 204) +@console_ns.route("/apps//workflows/draft/variables/") class VariableApi(Resource): _PATCH_NAME_FIELD = "name" _PATCH_VALUE_FIELD = "value" + @api.doc("get_variable") + @api.doc(description="Get a specific workflow variable") + @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) + @api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) + @api.response(404, "Variable not found") @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) def get(self, app_model: App, variable_id: str): @@ -240,6 +263,19 @@ class VariableApi(Resource): raise NotFoundError(description=f"variable not found, id={variable_id}") return variable + @api.doc("update_variable") + @api.doc(description="Update a workflow variable") + @api.expect( + api.model( + "UpdateVariableRequest", + { + "name": fields.String(description="Variable name"), + "value": fields.Raw(description="Variable value"), + }, + ) + ) + @api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) + @api.response(404, "Variable not found") @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) def patch(self, app_model: App, variable_id: str): @@ -302,6 +338,10 @@ class VariableApi(Resource): db.session.commit() return variable + @api.doc("delete_variable") + @api.doc(description="Delete a workflow variable") + @api.response(204, "Variable deleted successfully") + @api.response(404, "Variable not found") @_api_prerequisite def delete(self, app_model: App, variable_id: str): draft_var_srv = WorkflowDraftVariableService( @@ -317,7 +357,14 @@ class VariableApi(Resource): return Response("", 204) +@console_ns.route("/apps//workflows/draft/variables//reset") class VariableResetApi(Resource): + @api.doc("reset_variable") + @api.doc(description="Reset a workflow variable to its default value") + @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) + @api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) + @api.response(204, "Variable reset (no content)") + @api.response(404, "Variable not found") @_api_prerequisite def put(self, app_model: App, variable_id: str): draft_var_srv = WorkflowDraftVariableService( @@ -358,7 +405,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: return draft_vars +@console_ns.route("/apps//workflows/draft/conversation-variables") class ConversationVariableCollectionApi(Resource): + @api.doc("get_conversation_variables") + @api.doc(description="Get conversation variables for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + @api.response(404, "Draft workflow not found") @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App): @@ -374,14 +427,25 @@ class ConversationVariableCollectionApi(Resource): return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) +@console_ns.route("/apps//workflows/draft/system-variables") class SystemVariableCollectionApi(Resource): + @api.doc("get_system_variables") + @api.doc(description="Get system variables for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App): return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) +@console_ns.route("/apps//workflows/draft/environment-variables") class EnvironmentVariableCollectionApi(Resource): + @api.doc("get_environment_variables") + @api.doc(description="Get environment variables for workflow") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Environment variables retrieved successfully") + @api.response(404, "Draft workflow not found") @_api_prerequisite def get(self, app_model: App): """ @@ -413,16 +477,3 @@ class EnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} - - -api.add_resource( - WorkflowVariableCollectionApi, - "/apps//workflows/draft/variables", -) -api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") -api.add_resource(VariableApi, "/apps//workflows/draft/variables/") -api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") - -api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") -api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") -api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index dccbfd8648..23ba63845c 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -4,7 +4,7 @@ from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from flask_restx.inputs import int_range -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( @@ -19,7 +19,13 @@ from models import Account, App, AppMode, EndUser from services.workflow_run_service import WorkflowRunService +@console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): + @api.doc("get_advanced_chat_workflow_runs") + @api.doc(description="Get advanced chat workflow run list") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) @setup_required @login_required @account_initialization_required @@ -40,7 +46,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs") class WorkflowRunListApi(Resource): + @api.doc("get_workflow_runs") + @api.doc(description="Get workflow run list") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @setup_required @login_required @account_initialization_required @@ -61,7 +73,13 @@ class WorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs/") class WorkflowRunDetailApi(Resource): + @api.doc("get_workflow_run_detail") + @api.doc(description="Get workflow run detail") + @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields) + @api.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @@ -79,7 +97,13 @@ class WorkflowRunDetailApi(Resource): return workflow_run +@console_ns.route("/apps//workflow-runs//node-executions") class WorkflowRunNodeExecutionListApi(Resource): + @api.doc("get_workflow_run_node_executions") + @api.doc(description="Get workflow run node execution list") + @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields) + @api.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @@ -100,9 +124,3 @@ class WorkflowRunNodeExecutionListApi(Resource): ) return {"data": node_executions} - - -api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps//advanced-chat/workflow-runs") -api.add_resource(WorkflowRunListApi, "/apps//workflow-runs") -api.add_resource(WorkflowRunDetailApi, "/apps//workflow-runs/") -api.add_resource(WorkflowRunNodeExecutionListApi, "/apps//workflow-runs//node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index da7216086e..535e7cadd6 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -7,7 +7,7 @@ from flask import jsonify from flask_login import current_user from flask_restx import Resource, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db @@ -17,7 +17,13 @@ from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode +@console_ns.route("/apps//workflow/statistics/daily-conversations") class WorkflowDailyRunsStatistic(Resource): + @api.doc("get_workflow_daily_runs_statistic") + @api.doc(description="Get workflow daily runs statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Daily runs statistics retrieved successfully") @get_app_model @setup_required @login_required @@ -79,7 +85,13 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//workflow/statistics/daily-terminals") class WorkflowDailyTerminalsStatistic(Resource): + @api.doc("get_workflow_daily_terminals_statistic") + @api.doc(description="Get workflow daily terminals statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Daily terminals statistics retrieved successfully") @get_app_model @setup_required @login_required @@ -141,7 +153,13 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//workflow/statistics/token-costs") class WorkflowDailyTokenCostStatistic(Resource): + @api.doc("get_workflow_daily_token_cost_statistic") + @api.doc(description="Get workflow daily token cost statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Daily token cost statistics retrieved successfully") @get_app_model @setup_required @login_required @@ -208,7 +226,13 @@ WHERE return jsonify({"data": response_data}) +@console_ns.route("/apps//workflow/statistics/average-app-interactions") class WorkflowAverageAppInteractionStatistic(Resource): + @api.doc("get_workflow_average_app_interaction_statistic") + @api.doc(description="Get workflow average app interaction statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) + @api.response(200, "Average app interaction statistics retrieved successfully") @setup_required @login_required @account_initialization_required @@ -285,11 +309,3 @@ GROUP BY ) return jsonify({"data": response_data}) - - -api.add_resource(WorkflowDailyRunsStatistic, "/apps//workflow/statistics/daily-conversations") -api.add_resource(WorkflowDailyTerminalsStatistic, "/apps//workflow/statistics/daily-terminals") -api.add_resource(WorkflowDailyTokenCostStatistic, "/apps//workflow/statistics/token-costs") -api.add_resource( - WorkflowAverageAppInteractionStatistic, "/apps//workflow/statistics/average-app-interactions" -) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 5a871f896a..44aba01820 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, ParamSpec, TypeVar, Union +from typing import ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -12,7 +12,7 @@ P = ParamSpec("P") R = TypeVar("R") -def _load_app_model(app_id: str) -> Optional[App]: +def _load_app_model(app_id: str) -> App | None: assert isinstance(current_user, Account) app_model = ( db.session.query(App) @@ -22,7 +22,7 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): +def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py new file mode 100644 index 0000000000..91de19a78a --- /dev/null +++ b/api/controllers/console/auth/email_register.py @@ -0,0 +1,155 @@ +from flask import request +from flask_restx import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants.languages import languages +from controllers.console import api +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailCodeError, + EmailRegisterLimitError, + InvalidEmailError, + InvalidTokenError, + PasswordMismatchError, +) +from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError +from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required +from extensions.ext_database import db +from libs.helper import email, extract_remote_ip +from libs.password import valid_password +from models.account import Account +from services.account_service import AccountService +from services.billing_service import BillingService +from services.errors.account import AccountNotFoundError, AccountRegisterError + + +class EmailRegisterSendEmailApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + language = "en-US" + if args["language"] in languages: + language = args["language"] + + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + raise AccountInFreezeError() + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + token = None + token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) + return {"result": "success", "data": token} + + +class EmailRegisterCheckApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) + if is_email_register_error_rate_limit: + raise EmailRegisterLimitError() + + token_data = AccountService.get_email_register_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_email_register_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_email_register_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_email_register_token( + user_email, code=args["code"], additional_data={"phase": "register"} + ) + + AccountService.reset_email_register_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class EmailRegisterResetApi(Resource): + @setup_required + @email_password_login_enabled + @email_register_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + args = parser.parse_args() + + # Validate passwords match + if args["new_password"] != args["password_confirm"]: + raise PasswordMismatchError() + + # Validate token and get register data + register_data = AccountService.get_email_register_data(args["token"]) + if not register_data: + raise InvalidTokenError() + # Must use token in reset phase + if register_data.get("phase", "") != "register": + raise InvalidTokenError() + + # Revoke token to prevent reuse + AccountService.revoke_email_register_token(args["token"]) + + email = register_data.get("email", "") + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(email, args["password_confirm"]) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(email) + + return {"result": "success", "data": token_pair.model_dump()} + + def _create_new_account(self, email, password) -> Account | None: + # Create new account if allowed + account = None + try: + account = AccountService.create_account_and_tenant( + email=email, + name=email, + password=password, + interface_language=languages[0], + ) + except AccountRegisterError: + raise AccountInFreezeError() + + return account + + +api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email") +api.add_resource(EmailRegisterCheckApi, "/email-register/validity") +api.add_resource(EmailRegisterResetApi, "/email-register") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 7853bef917..81f1c6e70f 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,21 +27,43 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Too many password reset emails have been sent. Please try again in 1 minute." + description = "Too many password reset emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + + +class EmailRegisterRateLimitExceededError(BaseHTTPException): + error_code = "email_register_rate_limit_exceeded" + description = "Too many email register emails have been sent. Please try again in {minutes} minutes." + code = 429 + + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailChangeRateLimitExceededError(BaseHTTPException): error_code = "email_change_rate_limit_exceeded" - description = "Too many email change emails have been sent. Please try again in 1 minute." + description = "Too many email change emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class OwnerTransferRateLimitExceededError(BaseHTTPException): error_code = "owner_transfer_rate_limit_exceeded" - description = "Too many owner transfer emails have been sent. Please try again in 1 minute." + description = "Too many owner transfer emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 1): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailCodeError(BaseHTTPException): error_code = "email_code_error" @@ -69,15 +91,23 @@ class EmailPasswordLoginLimitError(BaseHTTPException): class EmailCodeLoginRateLimitExceededError(BaseHTTPException): error_code = "email_code_login_rate_limit_exceeded" - description = "Too many login emails have been sent. Please try again in 5 minutes." + description = "Too many login emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 5): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException): error_code = "email_code_account_deletion_rate_limit_exceeded" - description = "Too many account deletion emails have been sent. Please try again in 5 minutes." + description = "Too many account deletion emails have been sent. Please try again in {minutes} minutes." code = 429 + def __init__(self, minutes: int = 5): + description = self.description.format(minutes=int(minutes)) if self.description else None + super().__init__(description=description) + class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" @@ -85,6 +115,12 @@ class EmailPasswordResetLimitError(BaseHTTPException): code = 429 +class EmailRegisterLimitError(BaseHTTPException): + error_code = "email_register_limit" + description = "Too many failed email register attempts. Please try again in 24 hours." + code = 429 + + class EmailChangeLimitError(BaseHTTPException): error_code = "email_change_limit" description = "Too many failed email change attempts. Please try again in 24 hours." diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 7f34adc0f3..36ccb1d562 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -6,7 +6,6 @@ from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from constants.languages import languages from controllers.console import api, console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -15,7 +14,7 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError +from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db @@ -23,8 +22,6 @@ from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService -from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService @@ -73,15 +70,13 @@ class ForgotPasswordSendEmailApi(Resource): with Session(db.engine) as session: account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() - token = None - if account is None: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - return {"result": "fail", "data": token, "code": "account_not_found"} - else: - raise AccountNotFound() - else: - token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + token = AccountService.send_reset_password_email( + account=account, + email=args["email"], + language=language, + is_allow_register=FeatureService.get_system_features().is_allow_register, + ) return {"result": "success", "data": token} @@ -207,7 +202,7 @@ class ForgotPasswordResetApi(Resource): if account: self._update_existing_account(account, password_hashed, salt, session) else: - self._create_new_account(email, args["password_confirm"]) + raise AccountNotFound() return {"result": "success"} @@ -227,18 +222,7 @@ class ForgotPasswordResetApi(Resource): account.current_tenant = tenant tenant_was_created.send(tenant) - def _create_new_account(self, email, password): - # Create new account if allowed - try: - AccountService.create_account_and_tenant( - email=email, - name=email, - password=password, - interface_language=languages[0], - ) - except WorkSpaceNotAllowedCreateError: - pass - except WorkspacesLimitExceededError: - pass - except AccountRegisterError: - raise AccountInFreezeError() + +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index b11bc0c6ac..3b35ab3c23 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -26,7 +26,6 @@ from controllers.console.error import ( from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip -from libs.password import valid_password from models.account import Account from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService @@ -44,10 +43,9 @@ class LoginApi(Resource): """Authenticate user and login.""" parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("password", type=str, required=True, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") parser.add_argument("invite_token", type=str, required=False, default=None, location="json") - parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): @@ -61,11 +59,6 @@ class LoginApi(Resource): if invitation: invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) - if args["language"] is not None and args["language"] == "zh-Hans": - language = "zh-Hans" - else: - language = "en-US" - try: if invitation: data = invitation.get("data", {}) @@ -80,12 +73,6 @@ class LoginApi(Resource): except services.errors.account.AccountPasswordError: AccountService.add_login_error_rate_limit(args["email"]) raise AuthenticationFailedError() - except services.errors.account.AccountNotFoundError: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - return {"result": "fail", "data": token, "code": "account_not_found"} - else: - raise AccountNotFound() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: @@ -133,13 +120,12 @@ class ResetPasswordSendEmailApi(Resource): except AccountRegisterError: raise AccountInFreezeError() - if account is None: - if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_reset_password_email(email=args["email"], language=language) - else: - raise AccountNotFound() - else: - token = AccountService.send_reset_password_email(account=account, language=language) + token = AccountService.send_reset_password_email( + email=args["email"], + account=account, + language=language, + is_allow_register=FeatureService.get_system_features().is_allow_register, + ) return {"result": "success", "data": token} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index c3c9de1589..1602ee6eea 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests from flask import current_app, redirect, request @@ -18,6 +17,7 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService +from services.billing_service import BillingService from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService @@ -156,8 +156,8 @@ class OAuthCallback(Resource): ) -def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account: Optional[Account] = Account.get_by_openid(provider, user_info.id) +def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: + account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: with Session(db.engine) as session: @@ -183,7 +183,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: if not FeatureService.get_system_features().is_allow_register: - raise AccountNotFoundError() + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new account registration" + ) + ) + else: + raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 9fb092607a..6ed3d39a2b 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,13 +1,13 @@ import flask_restx from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError @@ -48,7 +48,21 @@ def _validate_description_length(description): return description +@console_ns.route("/datasets") class DatasetListApi(Resource): + @api.doc("get_datasets") + @api.doc(description="Get list of datasets") + @api.doc( + params={ + "page": "Page number (default: 1)", + "limit": "Number of items per page (default: 20)", + "ids": "Filter by dataset IDs (list)", + "keyword": "Search keyword", + "tag_ids": "Filter by tag IDs (list)", + "include_all": "Include all datasets (default: false)", + } + ) + @api.response(200, "Datasets retrieved successfully") @setup_required @login_required @account_initialization_required @@ -100,6 +114,24 @@ class DatasetListApi(Resource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @api.doc("create_dataset") + @api.doc(description="Create a new dataset") + @api.expect( + api.model( + "CreateDatasetRequest", + { + "name": fields.String(required=True, description="Dataset name (1-40 characters)"), + "description": fields.String(description="Dataset description (max 400 characters)"), + "indexing_technique": fields.String(description="Indexing technique"), + "permission": fields.String(description="Dataset permission"), + "provider": fields.String(description="Provider"), + "external_knowledge_api_id": fields.String(description="External knowledge API ID"), + "external_knowledge_id": fields.String(description="External knowledge ID"), + }, + ) + ) + @api.response(201, "Dataset created successfully") + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -172,7 +204,14 @@ class DatasetListApi(Resource): return marshal(dataset, dataset_detail_fields), 201 +@console_ns.route("/datasets/") class DatasetApi(Resource): + @api.doc("get_dataset") + @api.doc(description="Get dataset details") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Dataset retrieved successfully", dataset_detail_fields) + @api.response(404, "Dataset not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -215,6 +254,23 @@ class DatasetApi(Resource): return data, 200 + @api.doc("update_dataset") + @api.doc(description="Update dataset details") + @api.expect( + api.model( + "UpdateDatasetRequest", + { + "name": fields.String(description="Dataset name"), + "description": fields.String(description="Dataset description"), + "permission": fields.String(description="Dataset permission"), + "indexing_technique": fields.String(description="Indexing technique"), + "external_retrieval_model": fields.Raw(description="External retrieval model settings"), + }, + ) + ) + @api.response(200, "Dataset updated successfully", dataset_detail_fields) + @api.response(404, "Dataset not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -344,7 +400,12 @@ class DatasetApi(Resource): raise DatasetInUseError() +@console_ns.route("/datasets//use-check") class DatasetUseCheckApi(Resource): + @api.doc("check_dataset_use") + @api.doc(description="Check if dataset is in use") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Dataset use status retrieved successfully") @setup_required @login_required @account_initialization_required @@ -355,7 +416,12 @@ class DatasetUseCheckApi(Resource): return {"is_using": dataset_is_using}, 200 +@console_ns.route("/datasets//queries") class DatasetQueryApi(Resource): + @api.doc("get_dataset_queries") + @api.doc(description="Get dataset query history") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Query history retrieved successfully", dataset_query_detail_fields) @setup_required @login_required @account_initialization_required @@ -385,7 +451,11 @@ class DatasetQueryApi(Resource): return response, 200 +@console_ns.route("/datasets/indexing-estimate") class DatasetIndexingEstimateApi(Resource): + @api.doc("estimate_dataset_indexing") + @api.doc(description="Estimate dataset indexing cost") + @api.response(200, "Indexing estimate calculated successfully") @setup_required @login_required @account_initialization_required @@ -486,7 +556,12 @@ class DatasetIndexingEstimateApi(Resource): return response.model_dump(), 200 +@console_ns.route("/datasets//related-apps") class DatasetRelatedAppListApi(Resource): + @api.doc("get_dataset_related_apps") + @api.doc(description="Get applications related to dataset") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Related apps retrieved successfully", related_app_list) @setup_required @login_required @account_initialization_required @@ -513,7 +588,12 @@ class DatasetRelatedAppListApi(Resource): return {"data": related_apps, "total": len(related_apps)}, 200 +@console_ns.route("/datasets//indexing-status") class DatasetIndexingStatusApi(Resource): + @api.doc("get_dataset_indexing_status") + @api.doc(description="Get dataset indexing status") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Indexing status retrieved successfully") @setup_required @login_required @account_initialization_required @@ -560,11 +640,15 @@ class DatasetIndexingStatusApi(Resource): return data, 200 +@console_ns.route("/datasets/api-keys") class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" resource_type = "dataset" + @api.doc("get_dataset_api_keys") + @api.doc(description="Get dataset API keys") + @api.response(200, "API keys retrieved successfully", api_key_list) @setup_required @login_required @account_initialization_required @@ -609,9 +693,14 @@ class DatasetApiKeyApi(Resource): return api_token, 200 +@console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): resource_type = "dataset" + @api.doc("delete_dataset_api_key") + @api.doc(description="Delete dataset API key") + @api.doc(params={"api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") @setup_required @login_required @account_initialization_required @@ -641,7 +730,11 @@ class DatasetApiDeleteApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/datasets/api-base-info") class DatasetApiBaseUrlApi(Resource): + @api.doc("get_dataset_api_base_info") + @api.doc(description="Get dataset API base information") + @api.response(200, "API base info retrieved successfully") @setup_required @login_required @account_initialization_required @@ -649,7 +742,11 @@ class DatasetApiBaseUrlApi(Resource): return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} +@console_ns.route("/datasets/retrieval-setting") class DatasetRetrievalSettingApi(Resource): + @api.doc("get_dataset_retrieval_setting") + @api.doc(description="Get dataset retrieval settings") + @api.response(200, "Retrieval settings retrieved successfully") @setup_required @login_required @account_initialization_required @@ -700,7 +797,12 @@ class DatasetRetrievalSettingApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") +@console_ns.route("/datasets/retrieval-setting/") class DatasetRetrievalSettingMockApi(Resource): + @api.doc("get_dataset_retrieval_setting_mock") + @api.doc(description="Get mock dataset retrieval settings by vector type") + @api.doc(params={"vector_type": "Vector store type"}) + @api.response(200, "Mock retrieval settings retrieved successfully") @setup_required @login_required @account_initialization_required @@ -749,7 +851,13 @@ class DatasetRetrievalSettingMockApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") +@console_ns.route("/datasets//error-docs") class DatasetErrorDocs(Resource): + @api.doc("get_dataset_error_docs") + @api.doc(description="Get dataset error documents") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Error documents retrieved successfully") + @api.response(404, "Dataset not found") @setup_required @login_required @account_initialization_required @@ -763,7 +871,14 @@ class DatasetErrorDocs(Resource): return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 +@console_ns.route("/datasets//permission-part-users") class DatasetPermissionUserListApi(Resource): + @api.doc("get_dataset_permission_users") + @api.doc(description="Get dataset permission user list") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Permission users retrieved successfully") + @api.response(404, "Dataset not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -784,7 +899,13 @@ class DatasetPermissionUserListApi(Resource): }, 200 +@console_ns.route("/datasets//auto-disable-logs") class DatasetAutoDisableLogApi(Resource): + @api.doc("get_dataset_auto_disable_logs") + @api.doc(description="Get dataset auto disable logs") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.response(200, "Auto disable logs retrieved successfully") + @api.response(404, "Dataset not found") @setup_required @login_required @account_initialization_required @@ -794,20 +915,3 @@ class DatasetAutoDisableLogApi(Resource): if dataset is None: raise NotFound("Dataset not found.") return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 - - -api.add_resource(DatasetListApi, "/datasets") -api.add_resource(DatasetApi, "/datasets/") -api.add_resource(DatasetUseCheckApi, "/datasets//use-check") -api.add_resource(DatasetQueryApi, "/datasets//queries") -api.add_resource(DatasetErrorDocs, "/datasets//error-docs") -api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") -api.add_resource(DatasetRelatedAppListApi, "/datasets//related-apps") -api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") -api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") -api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") -api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") -api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") -api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") -api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") -api.add_resource(DatasetAutoDisableLogApi, "/datasets//auto-disable-logs") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index f943fb3ccb..0b65967445 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,12 +5,12 @@ from typing import Literal, cast from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.app.error import ( ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -98,7 +98,12 @@ class DocumentResource(Resource): return documents +@console_ns.route("/datasets/process-rule") class GetProcessRuleApi(Resource): + @api.doc("get_process_rule") + @api.doc(description="Get dataset document processing rules") + @api.doc(params={"document_id": "Document ID (optional)"}) + @api.response(200, "Process rules retrieved successfully") @setup_required @login_required @account_initialization_required @@ -140,7 +145,21 @@ class GetProcessRuleApi(Resource): return {"mode": mode, "rules": rules, "limits": limits} +@console_ns.route("/datasets//documents") class DatasetDocumentListApi(Resource): + @api.doc("get_dataset_documents") + @api.doc(description="Get documents in a dataset") + @api.doc( + params={ + "dataset_id": "Dataset ID", + "page": "Page number (default: 1)", + "limit": "Number of items per page (default: 20)", + "keyword": "Search keyword", + "sort": "Sort order (default: -created_at)", + "fetch": "Fetch full details (default: false)", + } + ) + @api.response(200, "Documents retrieved successfully") @setup_required @login_required @account_initialization_required @@ -324,7 +343,23 @@ class DatasetDocumentListApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/datasets/init") class DatasetInitApi(Resource): + @api.doc("init_dataset") + @api.doc(description="Initialize dataset with documents") + @api.expect( + api.model( + "DatasetInitRequest", + { + "upload_file_id": fields.String(required=True, description="Upload file ID"), + "indexing_technique": fields.String(description="Indexing technique"), + "process_rule": fields.Raw(description="Processing rules"), + "data_source": fields.Raw(description="Data source configuration"), + }, + ) + ) + @api.response(201, "Dataset initialized successfully", dataset_and_document_fields) + @api.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required @@ -394,7 +429,14 @@ class DatasetInitApi(Resource): return response +@console_ns.route("/datasets//documents//indexing-estimate") class DocumentIndexingEstimateApi(DocumentResource): + @api.doc("estimate_document_indexing") + @api.doc(description="Estimate document indexing cost") + @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @api.response(200, "Indexing estimate calculated successfully") + @api.response(404, "Document not found") + @api.response(400, "Document already finished") @setup_required @login_required @account_initialization_required @@ -593,7 +635,13 @@ class DocumentBatchIndexingStatusApi(DocumentResource): return data +@console_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DocumentResource): + @api.doc("get_document_indexing_status") + @api.doc(description="Get document indexing status") + @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @api.response(200, "Indexing status retrieved successfully") + @api.response(404, "Document not found") @setup_required @login_required @account_initialization_required @@ -635,9 +683,21 @@ class DocumentIndexingStatusApi(DocumentResource): return marshal(document_dict, document_status_fields) +@console_ns.route("/datasets//documents/") class DocumentApi(DocumentResource): METADATA_CHOICES = {"all", "only", "without"} + @api.doc("get_document") + @api.doc(description="Get document details") + @api.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "metadata": "Metadata inclusion (all/only/without)", + } + ) + @api.response(200, "Document retrieved successfully") + @api.response(404, "Document not found") @setup_required @login_required @account_initialization_required @@ -746,7 +806,16 @@ class DocumentApi(DocumentResource): return {"result": "success"}, 204 +@console_ns.route("/datasets//documents//processing/") class DocumentProcessingApi(DocumentResource): + @api.doc("update_document_processing") + @api.doc(description="Update document processing status (pause/resume)") + @api.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"} + ) + @api.response(200, "Processing status updated successfully") + @api.response(404, "Document not found") + @api.response(400, "Invalid action") @setup_required @login_required @account_initialization_required @@ -781,7 +850,23 @@ class DocumentProcessingApi(DocumentResource): return {"result": "success"}, 200 +@console_ns.route("/datasets//documents//metadata") class DocumentMetadataApi(DocumentResource): + @api.doc("update_document_metadata") + @api.doc(description="Update document metadata") + @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @api.expect( + api.model( + "UpdateDocumentMetadataRequest", + { + "doc_type": fields.String(description="Document type"), + "doc_metadata": fields.Raw(description="Document metadata"), + }, + ) + ) + @api.response(200, "Document metadata updated successfully") + @api.response(404, "Document not found") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -1015,26 +1100,3 @@ class WebsiteDocumentSyncApi(DocumentResource): DocumentService.sync_website_document(dataset_id, document) return {"result": "success"}, 200 - - -api.add_resource(GetProcessRuleApi, "/datasets/process-rule") -api.add_resource(DatasetDocumentListApi, "/datasets//documents") -api.add_resource(DatasetInitApi, "/datasets/init") -api.add_resource( - DocumentIndexingEstimateApi, "/datasets//documents//indexing-estimate" -) -api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") -api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") -api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") -api.add_resource(DocumentApi, "/datasets//documents/") -api.add_resource( - DocumentProcessingApi, "/datasets//documents//processing/" -) -api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") -api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") -api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") -api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") -api.add_resource(DocumentRetryApi, "/datasets//retry") -api.add_resource(DocumentRenameApi, "/datasets//documents//rename") - -api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 043f39f623..7195a5dd11 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,10 +1,10 @@ from flask import request from flask_login import current_user -from flask_restx import Resource, marshal, reqparse +from flask_restx import Resource, fields, marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, setup_required from fields.dataset_fields import dataset_detail_fields @@ -21,7 +21,18 @@ def _validate_name(name): return name +@console_ns.route("/datasets/external-knowledge-api") class ExternalApiTemplateListApi(Resource): + @api.doc("get_external_api_templates") + @api.doc(description="Get external knowledge API templates") + @api.doc( + params={ + "page": "Page number (default: 1)", + "limit": "Number of items per page (default: 20)", + "keyword": "Search keyword", + } + ) + @api.response(200, "External API templates retrieved successfully") @setup_required @login_required @account_initialization_required @@ -79,7 +90,13 @@ class ExternalApiTemplateListApi(Resource): return external_knowledge_api.to_dict(), 201 +@console_ns.route("/datasets/external-knowledge-api/") class ExternalApiTemplateApi(Resource): + @api.doc("get_external_api_template") + @api.doc(description="Get external knowledge API template details") + @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) + @api.response(200, "External API template retrieved successfully") + @api.response(404, "Template not found") @setup_required @login_required @account_initialization_required @@ -138,7 +155,12 @@ class ExternalApiTemplateApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/datasets/external-knowledge-api//use-check") class ExternalApiUseCheckApi(Resource): + @api.doc("check_external_api_usage") + @api.doc(description="Check if external knowledge API is being used") + @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) + @api.response(200, "Usage check completed successfully") @setup_required @login_required @account_initialization_required @@ -151,7 +173,24 @@ class ExternalApiUseCheckApi(Resource): return {"is_using": external_knowledge_api_is_using, "count": count}, 200 +@console_ns.route("/datasets/external") class ExternalDatasetCreateApi(Resource): + @api.doc("create_external_dataset") + @api.doc(description="Create external knowledge dataset") + @api.expect( + api.model( + "CreateExternalDatasetRequest", + { + "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"), + "external_knowledge_id": fields.String(required=True, description="External knowledge ID"), + "name": fields.String(required=True, description="Dataset name"), + "description": fields.String(description="Dataset description"), + }, + ) + ) + @api.response(201, "External dataset created successfully", dataset_detail_fields) + @api.response(400, "Invalid parameters") + @api.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @@ -191,7 +230,24 @@ class ExternalDatasetCreateApi(Resource): return marshal(dataset, dataset_detail_fields), 201 +@console_ns.route("/datasets//external-hit-testing") class ExternalKnowledgeHitTestingApi(Resource): + @api.doc("test_external_knowledge_retrieval") + @api.doc(description="Test external knowledge retrieval for dataset") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.expect( + api.model( + "ExternalHitTestingRequest", + { + "query": fields.String(required=True, description="Query text for testing"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "external_retrieval_model": fields.Raw(description="External retrieval model configuration"), + }, + ) + ) + @api.response(200, "External hit testing completed successfully") + @api.response(404, "Dataset not found") + @api.response(400, "Invalid parameters") @setup_required @login_required @account_initialization_required @@ -228,8 +284,22 @@ class ExternalKnowledgeHitTestingApi(Resource): raise InternalServerError(str(e)) +@console_ns.route("/test/retrieval") class BedrockRetrievalApi(Resource): # this api is only for internal testing + @api.doc("bedrock_retrieval_test") + @api.doc(description="Bedrock retrieval test (internal use only)") + @api.expect( + api.model( + "BedrockRetrievalTestRequest", + { + "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"), + "query": fields.String(required=True, description="Query text"), + "knowledge_id": fields.String(required=True, description="Knowledge ID"), + }, + ) + ) + @api.response(200, "Bedrock retrieval test completed") def post(self): parser = reqparse.RequestParser() parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") @@ -247,12 +317,3 @@ class BedrockRetrievalApi(Resource): args["retrieval_setting"], args["query"], args["knowledge_id"] ) return result, 200 - - -api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets//external-hit-testing") -api.add_resource(ExternalDatasetCreateApi, "/datasets/external") -api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") -api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/") -api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api//use-check") -# this api is only for internal test -api.add_resource(BedrockRetrievalApi, "/test/retrieval") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 2ad192571b..abaca88090 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,6 +1,6 @@ -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.wraps import ( account_initialization_required, @@ -10,7 +10,25 @@ from controllers.console.wraps import ( from libs.login import login_required +@console_ns.route("/datasets//hit-testing") class HitTestingApi(Resource, DatasetsHitTestingBase): + @api.doc("test_dataset_retrieval") + @api.doc(description="Test dataset knowledge retrieval") + @api.doc(params={"dataset_id": "Dataset ID"}) + @api.expect( + api.model( + "HitTestingRequest", + { + "query": fields.String(required=True, description="Query text for testing"), + "retrieval_model": fields.Raw(description="Retrieval model configuration"), + "top_k": fields.Integer(description="Number of top results to return"), + "score_threshold": fields.Float(description="Score threshold for filtering results"), + }, + ) + ) + @api.response(200, "Hit testing completed successfully") + @api.response(404, "Dataset not found") + @api.response(400, "Invalid parameters") @setup_required @login_required @account_initialization_required @@ -23,6 +41,3 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) - - -api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bdaa268462..b9c1f65bfd 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,13 +1,32 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService +@console_ns.route("/website/crawl") class WebsiteCrawlApi(Resource): + @api.doc("crawl_website") + @api.doc(description="Crawl website content") + @api.expect( + api.model( + "WebsiteCrawlRequest", + { + "provider": fields.String( + required=True, + description="Crawl provider (firecrawl/watercrawl/jinareader)", + enum=["firecrawl", "watercrawl", "jinareader"], + ), + "url": fields.String(required=True, description="URL to crawl"), + "options": fields.Raw(required=True, description="Crawl options"), + }, + ) + ) + @api.response(200, "Website crawl initiated successfully") + @api.response(400, "Invalid crawl parameters") @setup_required @login_required @account_initialization_required @@ -39,7 +58,14 @@ class WebsiteCrawlApi(Resource): return result, 200 +@console_ns.route("/website/crawl/status/") class WebsiteCrawlStatusApi(Resource): + @api.doc("get_crawl_status") + @api.doc(description="Get website crawl status") + @api.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) + @api.response(200, "Crawl status retrieved successfully") + @api.response(404, "Crawl job not found") + @api.response(400, "Invalid provider") @setup_required @login_required @account_initialization_required @@ -62,7 +88,3 @@ class WebsiteCrawlStatusApi(Resource): except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 - - -api.add_resource(WebsiteCrawlApi, "/website/crawl") -api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/") diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index d9afb5bab2..7742ea24a9 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource): if app_model is None: raise AppUnavailableError() - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 6401f804c0..3a8ba64a03 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask_login import current_user from flask_restx import Resource @@ -20,7 +20,7 @@ R = TypeVar("R") T = TypeVar("T") -def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): +def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): @@ -50,7 +50,7 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], return decorator -def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): +def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index e375fe285b..092071481e 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -242,6 +242,19 @@ def email_password_login_enabled(view: Callable[P, R]): return decorated +def email_register_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.is_allow_register: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated + + def enable_change_email(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 26fbf7097e..f8976b86b9 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -14,6 +14,15 @@ api = ExternalApi( files_ns = Namespace("files", description="File operations", path="/") -from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport] +from . import image_preview, tool_files, upload api.add_namespace(files_ns) + +__all__ = [ + "api", + "bp", + "files_ns", + "image_preview", + "tool_files", + "upload", +] diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 7a2b3b0428..206a5d1cc2 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,5 +1,4 @@ from mimetypes import guess_extension -from typing import Optional from flask_restx import Resource, reqparse from flask_restx.api import HTTPStatus @@ -73,11 +72,11 @@ class PluginUploadFileApi(Resource): nonce: str = args["nonce"] sign: str = args["sign"] tenant_id: str = args["tenant_id"] - user_id: Optional[str] = args.get("user_id") + user_id: str | None = args.get("user_id") user = get_user(tenant_id, user_id) - filename: Optional[str] = file.filename - mimetype: Optional[str] = file.mimetype + filename: str | None = file.filename + mimetype: str | None = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") @@ -86,7 +85,7 @@ class PluginUploadFileApi(Resource): filename=filename, mimetype=mimetype, tenant_id=tenant_id, - user_id=user_id, + user_id=user.id, timestamp=timestamp, nonce=nonce, sign=sign, diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index f29f624ba5..74005217ef 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -15,8 +15,17 @@ api = ExternalApi( # Create namespace inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") -from . import mail as _mail # pyright: ignore[reportUnusedImport] -from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport] -from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport] +from . import mail as _mail +from .plugin import plugin as _plugin +from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) + +__all__ = [ + "_mail", + "_plugin", + "_workspace", + "api", + "bp", + "inner_api_ns", +] diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index bde0150ffd..3776d0be0e 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, ParamSpec, TypeVar, cast +from typing import ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in @@ -8,11 +8,10 @@ from flask_restx import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session -from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db from libs.login import current_user from models.account import Tenant -from models.model import EndUser +from models.model import DefaultEndUserSessionID, EndUser P = ParamSpec("P") R = TypeVar("R") @@ -28,7 +27,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: try: with Session(db.engine) as session: if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value user_model = ( session.query(EndUser) @@ -42,7 +41,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = EndUser( tenant_id=tenant_id, type="service_api", - is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, session_id=user_id, ) session.add(user_model) @@ -55,7 +54,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable[P, R]] = None): +def get_user_tenant(view: Callable[P, R] | None = None): def decorator(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): @@ -73,7 +72,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): raise ValueError("tenant_id is required") if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value try: tenant_model = ( @@ -107,7 +106,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): return decorator(view) -def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): +def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def decorator(view_func: Callable[P, R]): def decorated_view(*args: P.args, **kwargs: P.kwargs): try: diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index 336a7801bb..d6fb2981e4 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -14,6 +14,13 @@ api = ExternalApi( mcp_ns = Namespace("mcp", description="MCP operations", path="/") -from . import mcp # pyright: ignore[reportUnusedImport] +from . import mcp api.add_namespace(mcp_ns) + +__all__ = [ + "api", + "bp", + "mcp", + "mcp_ns", +] diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 43b59d5334..a8629dca20 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from flask import Response from flask_restx import Resource, reqparse @@ -73,7 +73,7 @@ class MCPAppApi(Resource): ValidationError: Invalid request format or parameters """ args = mcp_request_parser.parse_args() - request_id: Optional[Union[int, str]] = args.get("id") + request_id: Union[int, str] | None = args.get("id") mcp_request = self._parse_mcp_request(args) with Session(db.engine, expire_on_commit=False) as session: @@ -107,7 +107,7 @@ class MCPAppApi(Resource): def _process_mcp_message( self, mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification, - request_id: Optional[Union[int, str]], + request_id: Union[int, str] | None, app: App, mcp_server: AppMCPServer, user_input_form: list[VariableEntity], @@ -130,7 +130,7 @@ class MCPAppApi(Resource): def _handle_request( self, mcp_request: mcp_types.ClientRequest, - request_id: Optional[Union[int, str]], + request_id: Union[int, str] | None, app: App, mcp_server: AppMCPServer, user_input_form: list[VariableEntity], @@ -150,7 +150,7 @@ class MCPAppApi(Resource): def _get_user_input_form(self, app: App) -> list[VariableEntity]: """Get and convert user input form""" # Get raw user input form based on app mode - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if not app.workflow: raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") raw_user_input_form = app.workflow.user_input_form(to_old_structure=True) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index a6008fdb99..9032733e2c 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -14,27 +14,46 @@ api = ExternalApi( service_api_ns = Namespace("service_api", description="Service operations", path="/") -from . import index # pyright: ignore[reportUnusedImport] +from . import index from .app import ( - annotation, # pyright: ignore[reportUnusedImport] - app, # pyright: ignore[reportUnusedImport] - audio, # pyright: ignore[reportUnusedImport] - completion, # pyright: ignore[reportUnusedImport] - conversation, # pyright: ignore[reportUnusedImport] - file, # pyright: ignore[reportUnusedImport] - file_preview, # pyright: ignore[reportUnusedImport] - message, # pyright: ignore[reportUnusedImport] - site, # pyright: ignore[reportUnusedImport] - workflow, # pyright: ignore[reportUnusedImport] + annotation, + app, + audio, + completion, + conversation, + file, + file_preview, + message, + site, + workflow, ) from .dataset import ( - dataset, # pyright: ignore[reportUnusedImport] - document, # pyright: ignore[reportUnusedImport] - hit_testing, # pyright: ignore[reportUnusedImport] - metadata, # pyright: ignore[reportUnusedImport] - segment, # pyright: ignore[reportUnusedImport] - upload_file, # pyright: ignore[reportUnusedImport] + dataset, + document, + hit_testing, + metadata, + segment, ) -from .workspace import models # pyright: ignore[reportUnusedImport] +from .workspace import models + +__all__ = [ + "annotation", + "app", + "audio", + "completion", + "conversation", + "dataset", + "document", + "file", + "file_preview", + "hit_testing", + "index", + "message", + "metadata", + "models", + "segment", + "site", + "workflow", +] api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9038bda11a..ad1bdc7334 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource): def put(self, app_model: App, annotation_id): """Update an existing annotation.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() annotation_id = str(annotation_id) @@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource): """Delete an annotation.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() annotation_id = str(annotation_id) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 2dbeed1d68..25d7ccccec 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -29,7 +29,7 @@ class AppParameterApi(Resource): Returns the input form parameters and configuration for the application. """ - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 580b08b9f0..99fde12e34 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -340,6 +340,9 @@ class DatasetApi(DatasetApiResource): else: data["embedding_available"] = True + # force update search method to keyword_search if indexing_technique is economic + data["retrieval_model_dict"]["search_method"] = "keyword_search" + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) data.update({"partial_member_list": part_users_list}) @@ -559,7 +562,7 @@ class DatasetTagsApi(DatasetApiResource): def post(self, _, dataset_id): """Add a knowledge type tag.""" assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_create_parser.parse_args() @@ -583,7 +586,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def patch(self, _, dataset_id): assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_update_parser.parse_args() @@ -610,7 +613,7 @@ class DatasetTagsApi(DatasetApiResource): def delete(self, _, dataset_id): """Delete a knowledge type tag.""" assert isinstance(current_user, Account) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() args = tag_delete_parser.parse_args() TagService.delete_tag(args.get("tag_id")) @@ -634,7 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource): def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_binding_parser.parse_args() @@ -660,7 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator assert isinstance(current_user, Account) - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() args = tag_unbinding_parser.parse_args() diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py deleted file mode 100644 index 27b36a6402..0000000000 --- a/api/controllers/service_api/dataset/upload_file.py +++ /dev/null @@ -1,65 +0,0 @@ -from werkzeug.exceptions import NotFound - -from controllers.service_api import service_api_ns -from controllers.service_api.wraps import ( - DatasetApiResource, -) -from core.file import helpers as file_helpers -from extensions.ext_database import db -from models.dataset import Dataset -from models.model import UploadFile -from services.dataset_service import DocumentService - - -@service_api_ns.route("/datasets//documents//upload-file") -class UploadFileApi(DatasetApiResource): - @service_api_ns.doc("get_upload_file") - @service_api_ns.doc(description="Get upload file information and download URL") - @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @service_api_ns.doc( - responses={ - 200: "Upload file information retrieved successfully", - 401: "Unauthorized - invalid API token", - 404: "Dataset, document, or upload file not found", - } - ) - def get(self, tenant_id, dataset_id, document_id): - """Get upload file information and download URL. - - Returns information about an uploaded file including its download URL. - """ - # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: - raise NotFound("Dataset not found.") - # check document - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) - if not document: - raise NotFound("Document not found.") - # check upload file - if document.data_source_type != "upload_file": - raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") - data_source_info = document.data_source_info_dict - if data_source_info and "upload_file_id" in data_source_info: - file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() - if not upload_file: - raise NotFound("UploadFile not found.") - else: - raise ValueError("Upload file id not found in document data source info.") - - url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": url, - "download_url": f"{url}&as_attachment=true", - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at.timestamp(), - }, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 64a2f5445c..1a40707c65 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -13,14 +13,13 @@ from sqlalchemy import select, update from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import current_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog -from models.model import ApiToken, App, EndUser +from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser from services.feature_service import FeatureService P = ParamSpec("P") @@ -43,7 +42,7 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): +def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None): def decorator(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): @@ -190,7 +189,7 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None): +def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): @@ -268,12 +267,12 @@ def validate_and_get_api_token(scope: str | None = None): return api_token -def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser: +def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = None) -> EndUser: """ Create or update session terminal based on user ID. """ if not user_id: - user_id = DEFAULT_SERVICE_API_USER_ID + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value with Session(db.engine, expire_on_commit=False) as session: end_user = ( @@ -292,7 +291,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] tenant_id=app_model.tenant_id, app_id=app_model.id, type="service_api", - is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, session_id=user_id, ) session.add(end_user) diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 97bcd3d53c..1d22954308 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -16,20 +16,40 @@ api = ExternalApi( web_ns = Namespace("web", description="Web application API operations", path="/") from . import ( - app, # pyright: ignore[reportUnusedImport] - audio, # pyright: ignore[reportUnusedImport] - completion, # pyright: ignore[reportUnusedImport] - conversation, # pyright: ignore[reportUnusedImport] - feature, # pyright: ignore[reportUnusedImport] - files, # pyright: ignore[reportUnusedImport] - forgot_password, # pyright: ignore[reportUnusedImport] - login, # pyright: ignore[reportUnusedImport] - message, # pyright: ignore[reportUnusedImport] - passport, # pyright: ignore[reportUnusedImport] - remote_files, # pyright: ignore[reportUnusedImport] - saved_message, # pyright: ignore[reportUnusedImport] - site, # pyright: ignore[reportUnusedImport] - workflow, # pyright: ignore[reportUnusedImport] + app, + audio, + completion, + conversation, + feature, + files, + forgot_password, + login, + message, + passport, + remote_files, + saved_message, + site, + workflow, ) api.add_namespace(web_ns) + +__all__ = [ + "api", + "app", + "audio", + "bp", + "completion", + "conversation", + "feature", + "files", + "forgot_password", + "login", + "message", + "passport", + "remote_files", + "saved_message", + "site", + "web_ns", + "workflow", +] diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index e0c3e997ce..2bc068ec75 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource): @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index e79456535a..ba03c4eae4 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import Concatenate, Optional, ParamSpec, TypeVar +from typing import Concatenate, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -21,7 +21,7 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None): +def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None): def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bcf83de6a..0a874e9085 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,7 @@ import json import logging import uuid -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select @@ -60,8 +60,8 @@ class BaseAgentRunner(AppRunner): message: Message, user_id: str, model_instance: ModelInstance, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, + memory: TokenBufferMemory | None = None, + prompt_messages: list[PromptMessage] | None = None, ): self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query: Optional[str] = "" + self.query: str | None = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d1d5a011e0..25ad6dc060 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -70,12 +70,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): self._prompt_messages_tools = prompt_messages_tools function_call_state = True - llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} + llm_usage: dict[str, LLMUsage | None] = {"usage": None} final_answer = "" prompt_messages: list = [] # Initialize prompt_messages agent_thought_id = "" # Initialize agent_thought_id - def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -122,7 +122,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): callbacks=[], ) - usage_dict: dict[str, Optional[LLMUsage]] = {} + usage_dict: dict[str, LLMUsage | None] = {} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -274,7 +274,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): action: AgentScratchpadUnit.Action, tool_instances: Mapping[str, Tool], message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None, + trace_manager: TraceQueueManager | None = None, ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3a4d31e047..da9a001d84 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,5 +1,4 @@ import json -from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner from core.model_runtime.entities.message_entities import ( @@ -31,7 +30,7 @@ class CotCompletionAgentRunner(CotAgentRunner): return system_prompt - def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str: + def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str: """ Organize historic prompt """ diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 816d2782f0..220feced1d 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field @@ -50,11 +50,11 @@ class AgentScratchpadUnit(BaseModel): "action_input": self.action_input, } - agent_response: Optional[str] = None - thought: Optional[str] = None - action_str: Optional[str] = None - observation: Optional[str] = None - action: Optional[Action] = None + agent_response: str | None = None + thought: str | None = None + action_str: str | None = None + observation: str | None = None + action: Action | None = None def is_final(self) -> bool: """ @@ -81,8 +81,8 @@ class AgentEntity(BaseModel): provider: str model: str strategy: Strategy - prompt: Optional[AgentPromptEntity] = None - tools: Optional[list[AgentToolEntity]] = None + prompt: AgentPromptEntity | None = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 10 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5236266908..dcc1326b33 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Generator from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom @@ -52,14 +52,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): # continue to run until there is not any tool call function_call_state = True - llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} + llm_usage: dict[str, LLMUsage | None] = {"usage": None} final_answer = "" prompt_messages: list = [] # Initialize prompt_messages # get tracing instance trace_manager = app_generate_entity.trace_manager - def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index a3438fc2c7..90aa7b5fd4 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,5 +1,5 @@ -import enum -from typing import Any, Optional +from enum import StrEnum +from typing import Any from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity): class AgentStrategyParameter(PluginParameter): - class AgentStrategyParameterType(enum.StrEnum): + class AgentStrategyParameterType(StrEnum): """ Keep all the types from PluginParameterType """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -53,7 +53,7 @@ class AgentStrategyParameter(PluginParameter): return cast_parameter_value(self, value) type: AgentStrategyParameterType = Field(..., description="The type of the parameter") - help: Optional[I18nObject] = None + help: I18nObject | None = None def init_frontend_parameter(self, value: Any): return init_frontend_parameter(self, self.type, value) @@ -61,7 +61,7 @@ class AgentStrategyParameter(PluginParameter): class AgentStrategyProviderEntity(BaseModel): identity: AgentStrategyProviderIdentity - plugin_id: Optional[str] = Field(None, description="The id of the plugin") + plugin_id: str | None = Field(None, description="The id of the plugin") class AgentStrategyIdentity(ToolIdentity): @@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity): pass -class AgentFeature(enum.StrEnum): +class AgentFeature(StrEnum): """ Agent Feature, used to describe the features of the agent strategy. """ @@ -84,9 +84,9 @@ class AgentStrategyEntity(BaseModel): identity: AgentStrategyIdentity parameters: list[AgentStrategyParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The description of the agent strategy") - output_schema: Optional[dict] = None - features: Optional[list[AgentFeature]] = None - meta_version: Optional[str] = None + output_schema: dict | None = None + features: list[AgentFeature] | None = None + meta_version: str | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index a52a1dfd7a..8a9be05dde 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Sequence -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyParameter @@ -16,10 +16,10 @@ class BaseAgentStrategy(ABC): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. @@ -37,9 +37,9 @@ class BaseAgentStrategy(ABC): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 04661581a7..a3cc798352 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter @@ -38,10 +38,10 @@ class PluginAgentStrategy(BaseAgentStrategy): self, params: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - credentials: Optional[InvokeCredentials] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 97ede178c7..e925d6dd52 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,12 +1,10 @@ -from typing import Optional - from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: + def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 8887d2500c..eab26e5af9 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,12 +1,10 @@ -from typing import Optional - from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[AgentEntity]: + def convert(cls, config: dict) -> AgentEntity | None: """ Convert model config to model config diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index fcbf479e2e..4b824bde76 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional from core.app.app_config.entities import ( DatasetEntity, @@ -14,7 +13,7 @@ from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[DatasetEntity]: + def convert(cls, config: dict) -> DatasetEntity | None: """ Convert model config to model config diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index cda17c0010..ec4f6074ab 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -70,7 +70,7 @@ class PromptTemplateConfigManager: :param config: app model config args """ if not config.get("prompt_type"): - config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] if config["prompt_type"] not in prompt_type_vals: @@ -90,7 +90,7 @@ class PromptTemplateConfigManager: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED: if not config["chat_prompt_config"] and not config["completion_prompt_config"]: raise ValueError( "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index df2074df2c..533cb37f8f 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,6 +1,6 @@ from collections.abc import Sequence -from enum import Enum, StrEnum -from typing import Any, Literal, Optional +from enum import StrEnum, auto +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel): provider: str model: str - mode: Optional[str] = None + mode: str | None = None parameters: dict[str, Any] = Field(default_factory=dict) stop: list[str] = Field(default_factory=list) @@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): assistant: str prompt: str - role_prefix: Optional[RolePrefixEntity] = None + role_prefix: RolePrefixEntity | None = None class PromptTemplateEntity(BaseModel): @@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel): Prompt Template Entity. """ - class PromptType(Enum): + class PromptType(StrEnum): """ Prompt Type. 'simple', 'advanced' """ - SIMPLE = "simple" - ADVANCED = "advanced" + SIMPLE = auto() + ADVANCED = auto() @classmethod def value_of(cls, value: str): @@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel): raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType - simple_prompt_template: Optional[str] = None - advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None - advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None + simple_prompt_template: str | None = None + advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None + advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None class VariableEntityType(StrEnum): @@ -112,7 +112,7 @@ class VariableEntity(BaseModel): type: VariableEntityType required: bool = False hide: bool = False - max_length: Optional[int] = None + max_length: int | None = None options: Sequence[str] = Field(default_factory=list) allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_extensions: Sequence[str] = Field(default_factory=list) @@ -173,7 +173,7 @@ class ModelConfig(BaseModel): class Condition(BaseModel): """ - Conditon detail + Condition detail """ name: str @@ -186,8 +186,8 @@ class MetadataFilteringCondition(BaseModel): Metadata Filtering Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) class DatasetRetrieveConfigEntity(BaseModel): @@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Config Entity. """ - class RetrieveStrategy(Enum): + class RetrieveStrategy(StrEnum): """ Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = "single" - MULTIPLE = "multiple" + SINGLE = auto() + MULTIPLE = auto() @classmethod def value_of(cls, value: str): @@ -217,18 +217,18 @@ class DatasetRetrieveConfigEntity(BaseModel): return mode raise ValueError(f"invalid retrieve strategy value {value}") - query_variable: Optional[str] = None # Only when app mode is completion + query_variable: str | None = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy - top_k: Optional[int] = None - score_threshold: Optional[float] = 0.0 - rerank_mode: Optional[str] = "reranking_model" - reranking_model: Optional[dict] = None - weights: Optional[dict] = None - reranking_enabled: Optional[bool] = True - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + top_k: int | None = None + score_threshold: float | None = 0.0 + rerank_mode: str | None = "reranking_model" + reranking_model: dict | None = None + weights: dict | None = None + reranking_enabled: bool | None = True + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" + metadata_model_config: ModelConfig | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None class DatasetEntity(BaseModel): @@ -255,8 +255,8 @@ class TextToSpeechEntity(BaseModel): """ enabled: bool - voice: Optional[str] = None - language: Optional[str] = None + voice: str | None = None + language: str | None = None class TracingConfigEntity(BaseModel): @@ -269,15 +269,15 @@ class TracingConfigEntity(BaseModel): class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileUploadConfig] = None - opening_statement: Optional[str] = None + file_upload: FileUploadConfig | None = None + opening_statement: str | None = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False - text_to_speech: Optional[TextToSpeechEntity] = None - trace_config: Optional[TracingConfigEntity] = None + text_to_speech: TextToSpeechEntity | None = None + trace_config: TracingConfigEntity | None = None class AppConfig(BaseModel): @@ -290,15 +290,15 @@ class AppConfig(BaseModel): app_mode: AppMode additional_features: AppAdditionalFeatures variables: list[VariableEntity] = [] - sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None + sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None -class EasyUIBasedAppModelConfigFrom(Enum): +class EasyUIBasedAppModelConfigFrom(StrEnum): """ App Model Config From. """ - ARGS = "args" + ARGS = auto() APP_LATEST_CONFIG = "app-latest-config" CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" @@ -313,7 +313,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_dict: dict model: ModelConfigEntity prompt_template: PromptTemplateEntity - dataset: Optional[DatasetEntity] = None + dataset: DatasetEntity | None = None external_data_variables: list[ExternalDataVariableEntity] = [] diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 84b032d6ca..42e19001b3 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Literal, Optional, Union, overload +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity: AdvancedChatAppGenerateEntity, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, - conversation: Optional[Conversation] = None, + conversation: Conversation | None = None, stream: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 635754a201..b8e0b5b310 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -231,7 +231,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def query_app_annotations_to_reply( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index cec3b83674..23ce8a7880 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager from threading import Thread -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -233,7 +233,7 @@ class AdvancedChatAppGenerateTaskPipeline: return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id @@ -294,7 +294,7 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: """Fluent validation for graph runtime state.""" if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -411,8 +411,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueTextChunkEvent, *, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle text chunk events.""" @@ -538,8 +538,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" @@ -569,8 +569,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" @@ -601,8 +601,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueWorkflowFailedEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed events.""" @@ -636,8 +636,8 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueStopEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle stop events.""" @@ -677,7 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: QueueAdvancedChatMessageEndEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, + graph_runtime_state: GraphRuntimeState | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" @@ -775,10 +775,10 @@ class AdvancedChatAppGenerateTaskPipeline: self, event: Any, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + graph_runtime_state: GraphRuntimeState | None = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, ) -> Generator[StreamResponse, None, None]: """Dispatch events using elegant pattern matching.""" handlers = self._get_event_handlers() @@ -830,15 +830,15 @@ class AdvancedChatAppGenerateTaskPipeline: def _process_stream_response( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response using elegant Fluent Python patterns. Maintains exact same functionality as original 57-if-statement version. """ # Initialize graph runtime state - graph_runtime_state: Optional[GraphRuntimeState] = None + graph_runtime_state: GraphRuntimeState | None = None for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event @@ -888,7 +888,7 @@ class AdvancedChatAppGenerateTaskPipeline: if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): + def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 54d1a9595f..9ce841f432 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): Agent Chatbot App Config Entity. """ - agent: Optional[AgentEntity] = None + agent: AgentEntity | None = None class AgentChatAppConfigManager(BaseAppConfigManager): @@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager): cls, app_model: App, app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None, + conversation: Conversation | None = None, + override_config_dict: dict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6681fc6e48..8f13599ead 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, Union, final +from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session @@ -24,7 +24,7 @@ class BaseAppGenerator: def _prepare_user_inputs( self, *, - user_inputs: Optional[Mapping[str, Any]], + user_inputs: Mapping[str, Any] | None, variables: Sequence["VariableEntity"], tenant_id: str, strict_type_validation: bool = False, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 2a7fe7902b..a58795bccb 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -2,7 +2,7 @@ import queue import time from abc import abstractmethod from enum import IntEnum, auto -from typing import Any, Optional +from typing import Any from sqlalchemy.orm import DeclarativeMeta @@ -116,7 +116,7 @@ class AppQueueManager: Set task stop flag :return: """ - result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id)) + result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id)) if result is None: return diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index dafdcdd429..e7db3bc41b 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -82,11 +82,11 @@ class AppRunner: prompt_template_entity: PromptTemplateEntity, inputs: Mapping[str, str], files: Sequence["File"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + query: str | None = None, + context: str | None = None, + memory: TokenBufferMemory | None = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages :param context: @@ -161,7 +161,7 @@ class AppRunner: prompt_messages: list, text: str, stream: bool, - usage: Optional[LLMUsage] = None, + usage: LLMUsage | None = None, ): """ Direct output @@ -375,7 +375,7 @@ class AppRunner: def query_app_annotations_to_reply( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 96a3db8502..4b6720a3c3 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager): cls, app_model: App, app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None, + conversation: Conversation | None = None, + override_config_dict: dict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 937b2a7dd7..1b4d28a5b8 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,7 +1,7 @@ import time from collections.abc import Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from sqlalchemy.orm import Session @@ -140,7 +140,7 @@ class WorkflowResponseConverter: event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeStartStreamResponse]: + ) -> NodeStartStreamResponse | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: @@ -190,7 +190,7 @@ class WorkflowResponseConverter: | QueueNodeExceptionEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: + ) -> NodeFinishStreamResponse | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: @@ -235,7 +235,7 @@ class WorkflowResponseConverter: event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + ) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_execution_id: diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 3a1f29689d..eb1902f12e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None ) -> CompletionAppConfig: """ Convert app model config to completion app config diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 92f3b6507c..170c6a274b 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -84,7 +84,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): logger.exception("Failed to handle response, conversation_id: %s", conversation.id) raise e - def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig: if conversation: stmt = select(AppModelConfig).where( AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id @@ -112,7 +112,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity, ], - conversation: Optional[Conversation] = None, + conversation: Conversation | None = None, ) -> tuple[Conversation, Message]: """ Initialize generate records diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 60395f0416..83c29ca166 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Literal, Optional, Union, overload +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -53,7 +53,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -67,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, ) -> Mapping[str, Any]: ... @overload @@ -81,7 +81,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... def generate( @@ -94,7 +94,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -200,7 +200,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ @@ -434,7 +434,7 @@ class WorkflowAppGenerator(BaseAppGenerator): queue_manager: AppQueueManager, context: contextvars.Context, variable_loader: VariableLoader, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ): """ Generate worker in a new thread. diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 42b3575807..3026be27f8 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, cast +from typing import cast from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager @@ -31,7 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, variable_loader: VariableLoader, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, workflow: Workflow, system_user_id: str, ): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 1c950063dd..638c4e938c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy.orm import Session @@ -206,7 +206,7 @@ class WorkflowAppGenerateTaskPipeline: return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id @@ -268,7 +268,7 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState: """Fluent validation for graph runtime state.""" if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -474,8 +474,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowSucceededEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow succeeded events.""" @@ -508,8 +508,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueWorkflowPartialSuccessEvent, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow partial success events.""" @@ -543,8 +543,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - trace_manager: Optional[TraceQueueManager] = None, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle workflow failed and stop events.""" @@ -581,8 +581,8 @@ class WorkflowAppGenerateTaskPipeline: self, event: QueueTextChunkEvent, *, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle text chunk events.""" @@ -635,10 +635,10 @@ class WorkflowAppGenerateTaskPipeline: self, event: Any, *, - graph_runtime_state: Optional[GraphRuntimeState] = None, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, - queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + graph_runtime_state: GraphRuntimeState | None = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, + queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, ) -> Generator[StreamResponse, None, None]: """Dispatch events using elegant pattern matching.""" handlers = self._get_event_handlers() @@ -701,8 +701,8 @@ class WorkflowAppGenerateTaskPipeline: def _process_stream_response( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + tts_publisher: AppGeneratorTTSPublisher | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response using elegant Fluent Python patterns. @@ -769,7 +769,7 @@ class WorkflowAppGenerateTaskPipeline: session.commit() def _text_chunk_to_stream_response( - self, text: str, from_variable_selector: Optional[list[str]] = None + self, text: str, from_variable_selector: list[str] | None = None ) -> TextChunkStreamResponse: """ Handle completed event. diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1d5ebabaf7..4c0abd0983 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -96,7 +96,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: Any = None - file_upload_config: Optional[FileUploadConfig] = None + file_upload_config: FileUploadConfig | None = None inputs: Mapping[str, Any] files: Sequence[File] @@ -114,7 +114,7 @@ class AppGenerateEntity(BaseModel): # tracing instance # Using Any to avoid circular import with TraceQueueManager - trace_manager: Optional[Any] = None + trace_manager: Any | None = None class EasyUIBasedAppGenerateEntity(AppGenerateEntity): @@ -126,7 +126,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): app_config: EasyUIBasedAppConfig = None # type: ignore model_conf: ModelConfigWithCredentialsEntity - query: Optional[str] = None + query: str | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -137,8 +137,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity): Base entity for conversation-based app generation. """ - conversation_id: Optional[str] = None - parent_message_id: Optional[str] = Field( + conversation_id: str | None = None + parent_message_id: str | None = Field( default=None, description=( "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API." @@ -188,7 +188,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): # app config app_config: WorkflowUIBasedAppConfig = None # type: ignore - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None query: str class SingleIterationRunEntity(BaseModel): @@ -199,7 +199,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): node_id: str inputs: Mapping - single_iteration_run: Optional[SingleIterationRunEntity] = None + single_iteration_run: SingleIterationRunEntity | None = None class SingleLoopRunEntity(BaseModel): """ @@ -209,7 +209,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): node_id: str inputs: Mapping - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None class WorkflowAppGenerateEntity(AppGenerateEntity): @@ -229,7 +229,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None + single_iteration_run: SingleIterationRunEntity | None = None class SingleLoopRunEntity(BaseModel): """ @@ -239,4 +239,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index fc04e60836..6d2808b447 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum -from typing import Any, Optional +from enum import StrEnum, auto +from typing import Any from pydantic import BaseModel @@ -81,20 +81,20 @@ class QueueIterationStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None + metadata: Mapping[str, Any] | None = None class QueueIterationNextEvent(AppQueueEvent): @@ -109,19 +109,19 @@ class QueueIterationNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" node_run_index: int - output: Optional[Any] = None # output for the current iteration - duration: Optional[float] = None + output: Any | None = None # output for the current iteration + duration: float | None = None class QueueIterationCompletedEvent(AppQueueEvent): @@ -135,23 +135,23 @@ class QueueIterationCompletedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueLoopStartEvent(AppQueueEvent): @@ -164,20 +164,20 @@ class QueueLoopStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None + metadata: Mapping[str, Any] | None = None class QueueLoopNextEvent(AppQueueEvent): @@ -192,19 +192,19 @@ class QueueLoopNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" node_run_index: int - output: Optional[Any] = None # output for the current loop - duration: Optional[float] = None + output: Any | None = None # output for the current loop + duration: float | None = None class QueueLoopCompletedEvent(AppQueueEvent): @@ -218,23 +218,23 @@ class QueueLoopCompletedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueTextChunkEvent(AppQueueEvent): @@ -244,11 +244,11 @@ class QueueTextChunkEvent(AppQueueEvent): event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None """from variable selector""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -285,9 +285,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: Sequence[RetrievalSourceMetadata] - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -306,7 +306,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.MESSAGE_END - llm_result: Optional[LLMResult] = None + llm_result: LLMResult | None = None class QueueAdvancedChatMessageEndEvent(AppQueueEvent): @@ -332,7 +332,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class QueueWorkflowFailedEvent(AppQueueEvent): @@ -352,7 +352,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED exceptions_count: int - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class QueueNodeStartedEvent(AppQueueEvent): @@ -367,23 +367,23 @@ class QueueNodeStartedEvent(AppQueueEvent): node_type: NodeType node_data: BaseNodeData node_run_index: int = 1 - predecessor_node_id: Optional[str] = None - parallel_id: Optional[str] = None + predecessor_node_id: str | None = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" - agent_strategy: Optional[AgentNodeStrategyInit] = None + agent_strategy: AgentNodeStrategyInit | None = None class QueueNodeSucceededEvent(AppQueueEvent): @@ -397,30 +397,30 @@ class QueueNodeSucceededEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None - error: Optional[str] = None + error: str | None = None """single iteration duration map""" - iteration_duration_map: Optional[dict[str, float]] = None + iteration_duration_map: dict[str, float] | None = None """single loop duration map""" - loop_duration_map: Optional[dict[str, float]] = None + loop_duration_map: dict[str, float] | None = None class QueueAgentLogEvent(AppQueueEvent): @@ -436,7 +436,7 @@ class QueueAgentLogEvent(AppQueueEvent): error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, Any] | None = None node_id: str @@ -445,10 +445,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): event: QueueEvent = QueueEvent.RETRY - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str retry_index: int # retry index @@ -465,24 +465,24 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -498,24 +498,24 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -531,24 +531,24 @@ class QueueNodeExceptionEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -564,24 +564,24 @@ class QueueNodeFailedEvent(AppQueueEvent): node_id: str node_type: NodeType node_data: BaseNodeData - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -610,7 +610,7 @@ class QueueErrorEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ERROR - error: Optional[Any] = None + error: Any | None = None class QueuePingEvent(AppQueueEvent): @@ -626,15 +626,15 @@ class QueueStopEvent(AppQueueEvent): QueueStopEvent entity """ - class StopBy(Enum): + class StopBy(StrEnum): """ Stop by enum """ - USER_MANUAL = "user-manual" - ANNOTATION_REPLY = "annotation-reply" - OUTPUT_MODERATION = "output-moderation" - INPUT_MODERATION = "input-moderation" + USER_MANUAL = auto() + ANNOTATION_REPLY = auto() + OUTPUT_MODERATION = auto() + INPUT_MODERATION = auto() event: QueueEvent = QueueEvent.STOP stopped_by: StopBy @@ -689,13 +689,13 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -708,13 +708,13 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -727,12 +727,12 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 31183d19a3..92be2fce37 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence -from enum import Enum -from typing import Any, Optional +from enum import StrEnum, auto +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -50,37 +50,37 @@ class WorkflowTaskState(TaskState): answer: str = "" -class StreamEvent(Enum): +class StreamEvent(StrEnum): """ Stream event """ - PING = "ping" - ERROR = "error" - MESSAGE = "message" - MESSAGE_END = "message_end" - TTS_MESSAGE = "tts_message" - TTS_MESSAGE_END = "tts_message_end" - MESSAGE_FILE = "message_file" - MESSAGE_REPLACE = "message_replace" - AGENT_THOUGHT = "agent_thought" - AGENT_MESSAGE = "agent_message" - WORKFLOW_STARTED = "workflow_started" - WORKFLOW_FINISHED = "workflow_finished" - NODE_STARTED = "node_started" - NODE_FINISHED = "node_finished" - NODE_RETRY = "node_retry" - PARALLEL_BRANCH_STARTED = "parallel_branch_started" - PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" - ITERATION_STARTED = "iteration_started" - ITERATION_NEXT = "iteration_next" - ITERATION_COMPLETED = "iteration_completed" - LOOP_STARTED = "loop_started" - LOOP_NEXT = "loop_next" - LOOP_COMPLETED = "loop_completed" - TEXT_CHUNK = "text_chunk" - TEXT_REPLACE = "text_replace" - AGENT_LOG = "agent_log" + PING = auto() + ERROR = auto() + MESSAGE = auto() + MESSAGE_END = auto() + TTS_MESSAGE = auto() + TTS_MESSAGE_END = auto() + MESSAGE_FILE = auto() + MESSAGE_REPLACE = auto() + AGENT_THOUGHT = auto() + AGENT_MESSAGE = auto() + WORKFLOW_STARTED = auto() + WORKFLOW_FINISHED = auto() + NODE_STARTED = auto() + NODE_FINISHED = auto() + NODE_RETRY = auto() + PARALLEL_BRANCH_STARTED = auto() + PARALLEL_BRANCH_FINISHED = auto() + ITERATION_STARTED = auto() + ITERATION_NEXT = auto() + ITERATION_COMPLETED = auto() + LOOP_STARTED = auto() + LOOP_NEXT = auto() + LOOP_COMPLETED = auto() + TEXT_CHUNK = auto() + TEXT_REPLACE = auto() + AGENT_LOG = auto() class StreamResponse(BaseModel): @@ -110,7 +110,7 @@ class MessageStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE id: str answer: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None class MessageAudioStreamResponse(StreamResponse): @@ -139,7 +139,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = Field(default_factory=dict) - files: Optional[Sequence[Mapping[str, Any]]] = None + files: Sequence[Mapping[str, Any]] | None = None class MessageFileStreamResponse(StreamResponse): @@ -172,12 +172,12 @@ class AgentThoughtStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int - thought: Optional[str] = None - observation: Optional[str] = None - tool: Optional[str] = None - tool_labels: Optional[dict] = None - tool_input: Optional[str] = None - message_files: Optional[list[str]] = None + thought: str | None = None + observation: str | None = None + tool: str | None = None + tool_labels: dict | None = None + tool_input: str | None = None + message_files: list[str] | None = None class AgentMessageStreamResponse(StreamResponse): @@ -223,16 +223,16 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int - created_by: Optional[dict] = None + created_by: dict | None = None created_at: int finished_at: int - exceptions_count: Optional[int] = 0 - files: Optional[Sequence[Mapping[str, Any]]] = [] + exceptions_count: int | None = 0 + files: Sequence[Mapping[str, Any]] | None = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -254,18 +254,18 @@ class NodeStartStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None created_at: int extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None - parallel_run_id: Optional[str] = None - agent_strategy: Optional[AgentNodeStrategyInit] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None + parallel_run_id: str | None = None + agent_strategy: AgentNodeStrategyInit | None = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -311,23 +311,23 @@ class NodeFinishStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -380,23 +380,23 @@ class NodeRetryStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None retry_index: int = 0 event: StreamEvent = StreamEvent.NODE_RETRY @@ -448,10 +448,10 @@ class ParallelBranchStartStreamResponse(StreamResponse): parallel_id: str parallel_branch_id: str - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None created_at: int event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED @@ -471,12 +471,12 @@ class ParallelBranchFinishedStreamResponse(StreamResponse): parallel_id: str parallel_branch_id: str - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None status: str - error: Optional[str] = None + error: str | None = None created_at: int event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED @@ -502,8 +502,8 @@ class IterationNodeStartStreamResponse(StreamResponse): extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -526,12 +526,12 @@ class IterationNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - pre_iteration_output: Optional[Any] = None + pre_iteration_output: Any | None = None extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None - duration: Optional[float] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parallel_mode_run_id: str | None = None + duration: float | None = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -552,19 +552,19 @@ class IterationNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping | None = None finished_at: int steps: int - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str @@ -589,8 +589,8 @@ class LoopNodeStartStreamResponse(StreamResponse): extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_STARTED workflow_run_id: str @@ -613,12 +613,12 @@ class LoopNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - pre_loop_output: Optional[Any] = None + pre_loop_output: Any | None = None extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None - duration: Optional[float] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parallel_mode_run_id: str | None = None + duration: float | None = None event: StreamEvent = StreamEvent.LOOP_NEXT workflow_run_id: str @@ -639,19 +639,19 @@ class LoopNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping | None = None finished_at: int steps: int - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_COMPLETED workflow_run_id: str @@ -669,7 +669,7 @@ class TextChunkStreamResponse(StreamResponse): """ text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -731,7 +731,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): WorkflowAppStreamResponse entity """ - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None class AppBlockingResponse(BaseModel): @@ -796,8 +796,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int @@ -825,7 +825,7 @@ class AgentLogStreamResponse(StreamResponse): error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, Any] | None = None node_id: str event: StreamEvent = StreamEvent.AGENT_LOG diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 3853dccdc5..79fbafe39e 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from sqlalchemy import select @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) class AnnotationReplyFeature: def query( self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom - ) -> Optional[MessageAnnotation]: + ) -> MessageAnnotation | None: """ Query app annotations to reply :param app_record: app record diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 6f13f11da0..ffa10cd43c 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -3,7 +3,7 @@ import time import uuid from collections.abc import Generator, Mapping from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Union from core.errors.error import AppInvokeQuotaExceededError from extensions.ext_redis import redis_client @@ -63,7 +63,7 @@ class RateLimit: if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) - def enter(self, request_id: Optional[str] = None) -> str: + def enter(self, request_id: str | None = None) -> str: if self.disabled(): return RateLimit._UNLIMITED_REQUEST_ID if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 4931300901..45e3c0006b 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional from sqlalchemy import select from sqlalchemy.orm import Session @@ -101,7 +100,7 @@ class BasedGenerateTaskPipeline: """ return PingStreamResponse(task_id=self._application_generate_entity.task_id) - def _init_output_moderation(self) -> Optional[OutputModeration]: + def _init_output_moderation(self) -> OutputModeration | None: """ Init output moderation. :return: @@ -118,7 +117,7 @@ class BasedGenerateTaskPipeline: ) return None - def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + def handle_output_moderation_when_task_finished(self, completion: str) -> str | None: """ Handle output moderation when task finished. :param completion: completion diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 71fd5ac653..67abb569e3 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Optional, Union, cast +from typing import Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): task_state=self._task_state, ) - self._conversation_name_generate_thread: Optional[Thread] = None + self._conversation_name_generate_thread: Thread | None = None def process( self, @@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._task_state.metadata: extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation_mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( @@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): return None def _wrapper_process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id @@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): + def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None): """ Save message. :return: @@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) - def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: + def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None: """ Agent thought to stream response. :param event: agent thought event :return: """ with Session(db.engine, expire_on_commit=False) as session: - agent_thought: Optional[MessageAgentThought] = ( + agent_thought: MessageAgentThought | None = ( session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() ) diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index e865ba9d60..90ffdcf1f6 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -1,6 +1,6 @@ import logging from threading import Thread -from typing import Optional, Union +from typing import Union from flask import Flask, current_app from sqlalchemy import select @@ -52,7 +52,7 @@ class MessageCycleManager: self._application_generate_entity = application_generate_entity self._task_state = task_state - def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ Generate conversation name. :param conversation_id: conversation id @@ -92,7 +92,7 @@ class MessageCycleManager: if not conversation: return - if conversation.mode != AppMode.COMPLETION.value: + if conversation.mode != AppMode.COMPLETION: app_model = conversation.app if not app_model: return @@ -111,7 +111,7 @@ class MessageCycleManager: db.session.commit() db.session.close() - def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None: """ Handle annotation reply. :param event: event @@ -141,7 +141,7 @@ class MessageCycleManager: if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata.retriever_resources = event.retriever_resources - def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None: """ Message file to stream response. :param event: event @@ -180,7 +180,7 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + self, answer: str, message_id: str, from_variable_selector: list[str] | None = None ) -> MessageStreamResponse: """ Message to stream response. diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 89190c36cc..1e0fba6215 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -5,7 +5,6 @@ import queue import re import threading from collections.abc import Iterable -from typing import Optional from core.app.entities.queue_entities import ( MessageQueueMessage, @@ -56,7 +55,7 @@ def _process_future( class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None): + def __init__(self, tenant_id: str, voice: str, language: str | None = None): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" @@ -73,7 +72,7 @@ class AppGeneratorTTSPublisher: if not voice or voice not in values: self.voice = self.voices[0].get("value") self.max_sentence = 2 - self._last_audio_event: Optional[AudioTrunk] = None + self._last_audio_event: AudioTrunk | None = None # FIXME better way to handle this threading.start threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 30cdab26dc..9ee02acc92 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,5 +1,5 @@ from collections.abc import Iterable, Mapping -from typing import Any, Optional, TextIO, Union +from typing import Any, TextIO, Union from pydantic import BaseModel @@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str: return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): +def print_text(text: str, color: str | None = None, end: str = "", file: TextIO | None = None): """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) @@ -34,10 +34,10 @@ def print_text(text: str, color: Optional[str] = None, end: str = "", file: Opti class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = "" + color: str | None = "" current_loop: int = 1 - def __init__(self, color: Optional[str] = None): + def __init__(self, color: str | None = None): super().__init__() """Initialize callback handler.""" # use a specific color is not specified @@ -58,9 +58,9 @@ class DifyAgentCallbackHandler(BaseModel): tool_name: str, tool_inputs: Mapping[str, Any], tool_outputs: Iterable[ToolInvokeMessage] | str, - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, + message_id: str | None = None, + timer: Any | None = None, + trace_manager: TraceQueueManager | None = None, ): """If not the final action, print out observation.""" if dify_config.DEBUG: @@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel): else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): + def on_agent_finish(self, color: str | None = None, **kwargs: Any): """Run on agent end.""" if dify_config.DEBUG: print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 350b18772b..23aabd9970 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Iterable, Mapping -from typing import Any, Optional +from typing import Any from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text from core.ops.ops_trace_manager import TraceQueueManager @@ -14,9 +14,9 @@ class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): tool_name: str, tool_inputs: Mapping[str, Any], tool_outputs: Iterable[ToolInvokeMessage], - message_id: Optional[str] = None, - timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None, + message_id: str | None = None, + timer: Any | None = None, + trace_manager: TraceQueueManager | None = None, ) -> Generator[ToolInvokeMessage, None, None]: for tool_output in tool_outputs: print_text("\n[on_tool_execution]\n", color=self.color) diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 656bf4aa72..cf958b91d2 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -1,8 +1,8 @@ -from enum import Enum +from enum import StrEnum, auto -class PlanningStrategy(Enum): - ROUTER = "router" - REACT_ROUTER = "react_router" - REACT = "react" - FUNCTION_CALL = "function_call" +class PlanningStrategy(StrEnum): + ROUTER = auto() + REACT_ROUTER = auto() + REACT = auto() + FUNCTION_CALL = auto() diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 9b4934646b..89b48fd2ef 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,10 @@ -from enum import Enum +from enum import StrEnum, auto -class EmbeddingInputType(Enum): +class EmbeddingInputType(StrEnum): """ Enum for embedding input type. """ - DOCUMENT = "document" - QUERY = "query" + DOCUMENT = auto() + QUERY = auto() diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 90c9879733..6143b9b703 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,11 +1,9 @@ -from typing import Optional - from pydantic import BaseModel class PreviewDetail(BaseModel): content: str - child_chunks: Optional[list[str]] = None + child_chunks: list[str] | None = None class QAPreviewDetail(BaseModel): @@ -16,4 +14,4 @@ class QAPreviewDetail(BaseModel): class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] - qa_preview: Optional[list[QAPreviewDetail]] = None + qa_preview: list[QAPreviewDetail] | None = None diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 0fd49b059c..663a8164c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from enum import Enum -from typing import Optional +from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict @@ -9,16 +8,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel from core.model_runtime.entities.provider_entities import ProviderEntity -class ModelStatus(Enum): +class ModelStatus(StrEnum): """ Enum class for model status. """ - ACTIVE = "active" + ACTIVE = auto() NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" - DISABLED = "disabled" + DISABLED = auto() CREDENTIAL_REMOVED = "credential-removed" @@ -29,8 +28,8 @@ class SimpleModelProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: list[ModelType] def __init__(self, provider_entity: ProviderEntity): @@ -92,8 +91,8 @@ class DefaultModelProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] = [] diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index fbd62437e6..0afb51edce 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -1,20 +1,20 @@ -from enum import StrEnum +from enum import StrEnum, auto class CommonParameterType(StrEnum): SECRET_INPUT = "secret-input" TEXT_INPUT = "text-input" - SELECT = "select" - STRING = "string" - NUMBER = "number" - FILE = "file" - FILES = "files" + SELECT = auto() + STRING = auto() + NUMBER = auto() + FILE = auto() + FILES = auto() SYSTEM_FILES = "system-files" - BOOLEAN = "boolean" + BOOLEAN = auto() APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" - ANY = "any" + ANY = auto() # Dynamic select parameter # Once you are not sure about the available options until authorization is done @@ -23,29 +23,29 @@ class CommonParameterType(StrEnum): # TOOL_SELECTOR = "tool-selector" # MCP object and array type parameters - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class AppSelectorScope(StrEnum): - ALL = "all" - CHAT = "chat" - WORKFLOW = "workflow" - COMPLETION = "completion" + ALL = auto() + CHAT = auto() + WORKFLOW = auto() + COMPLETION = auto() class ModelSelectorScope(StrEnum): - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - TTS = "tts" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - VISION = "vision" + RERANK = auto() + TTS = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + VISION = auto() class ToolSelectorScope(StrEnum): - ALL = "all" - CUSTOM = "custom" - BUILTIN = "builtin" - WORKFLOW = "workflow" + ALL = auto() + CUSTOM = auto() + BUILTIN = auto() + WORKFLOW = auto() diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 5309e4e638..d694a27942 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -4,7 +4,6 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from typing import Optional from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import func, select @@ -92,7 +91,7 @@ class ProviderConfiguration(BaseModel): ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) - def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -165,7 +164,7 @@ class ProviderConfiguration(BaseModel): return credentials - def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: + def get_system_configuration_status(self) -> SystemConfigurationStatus | None: """ Get system configuration status. :return: @@ -793,9 +792,7 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_custom_model_credential( - self, model_type: ModelType, model: str, credential_id: str | None - ) -> Optional[dict]: + def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: """ Get custom model credentials. @@ -1272,7 +1269,7 @@ class ProviderConfiguration(BaseModel): return model_setting - def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: + def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: """ Get provider model setting. :param model_type: model type @@ -1448,7 +1445,7 @@ class ProviderConfiguration(BaseModel): def get_provider_model( self, model_type: ModelType, model: str, only_active: bool = False - ) -> Optional[ModelWithProviderEntity]: + ) -> ModelWithProviderEntity | None: """ Get provider model. :param model_type: model type @@ -1465,7 +1462,7 @@ class ProviderConfiguration(BaseModel): return None def get_provider_models( - self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None + self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None ) -> list[ModelWithProviderEntity]: """ Get provider models. @@ -1649,7 +1646,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], - model: Optional[str] = None, + model: str | None = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -1783,7 +1780,7 @@ class ProviderConfigurations(BaseModel): super().__init__(tenant_id=tenant_id) def get_models( - self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False ) -> list[ModelWithProviderEntity]: """ Get available models. diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 52acbc1eef..0496959ce2 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,5 +1,5 @@ -from enum import Enum -from typing import Optional, Union +from enum import StrEnum, auto +from typing import Union from pydantic import BaseModel, ConfigDict, Field @@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod @@ -31,25 +31,25 @@ class ProviderQuotaType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class QuotaUnit(Enum): - TIMES = "times" - TOKENS = "tokens" - CREDITS = "credits" +class QuotaUnit(StrEnum): + TIMES = auto() + TOKENS = auto() + CREDITS = auto() -class SystemConfigurationStatus(Enum): +class SystemConfigurationStatus(StrEnum): """ Enum class for system configuration status. """ - ACTIVE = "active" + ACTIVE = auto() QUOTA_EXCEEDED = "quota-exceeded" - UNSUPPORTED = "unsupported" + UNSUPPORTED = auto() class RestrictModel(BaseModel): model: str - base_model_name: Optional[str] = None + base_model_name: str | None = None model_type: ModelType # pydantic configs @@ -84,9 +84,9 @@ class SystemConfiguration(BaseModel): """ enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None + current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] - credentials: Optional[dict] = None + credentials: dict | None = None class CustomProviderConfiguration(BaseModel): @@ -95,8 +95,8 @@ class CustomProviderConfiguration(BaseModel): """ credentials: dict - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None + current_credential_id: str | None = None + current_credential_name: str | None = None available_credentials: list[CredentialConfiguration] = [] @@ -108,10 +108,10 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType credentials: dict | None = None - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None + current_credential_id: str | None = None + current_credential_name: str | None = None available_model_credentials: list[CredentialConfiguration] = [] - unadded_to_model_list: Optional[bool] = False + unadded_to_model_list: bool | None = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -131,7 +131,7 @@ class CustomConfiguration(BaseModel): Model class for provider custom configuration. """ - provider: Optional[CustomProviderConfiguration] = None + provider: CustomProviderConfiguration | None = None models: list[CustomModelConfiguration] = [] can_added_models: list[UnaddedModelConfiguration] = [] @@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel): Base model class for common provider settings like credentials """ - class Type(Enum): - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - TEXT_INPUT = CommonParameterType.TEXT_INPUT.value - SELECT = CommonParameterType.SELECT.value - BOOLEAN = CommonParameterType.BOOLEAN.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + class Type(StrEnum): + SECRET_INPUT = CommonParameterType.SECRET_INPUT + TEXT_INPUT = CommonParameterType.TEXT_INPUT + SELECT = CommonParameterType.SELECT + BOOLEAN = CommonParameterType.BOOLEAN + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod def value_of(cls, value: str) -> "ProviderConfig.Type": @@ -205,12 +205,12 @@ class ProviderConfig(BasicProviderConfig): scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None required: bool = False - default: Optional[Union[int, str, float, bool]] = None - options: Optional[list[Option]] = None - label: Optional[I18nObject] = None - help: Optional[I18nObject] = None - url: Optional[str] = None - placeholder: Optional[I18nObject] = None + default: Union[int, str, float, bool] | None = None + options: list[Option] | None = None + label: I18nObject | None = None + help: I18nObject | None = None + url: str | None = None + placeholder: I18nObject | None = None def to_basic_provider_config(self) -> BasicProviderConfig: return BasicProviderConfig(type=self.type, name=self.name) diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 642f24a411..8c1ba98ae1 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,12 +1,9 @@ -from typing import Optional - - class LLMError(ValueError): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index eee914a529..c2789a7a35 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,10 +1,10 @@ -import enum import importlib.util import json import logging import os +from enum import StrEnum, auto from pathlib import Path -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -13,18 +13,18 @@ from core.helper.position_helper import sort_to_dict_by_position_map logger = logging.getLogger(__name__) -class ExtensionModule(enum.Enum): - MODERATION = "moderation" - EXTERNAL_DATA_TOOL = "external_data_tool" +class ExtensionModule(StrEnum): + MODERATION = auto() + EXTERNAL_DATA_TOOL = auto() class ModuleExtension(BaseModel): - extension_class: Optional[Any] = None + extension_class: Any | None = None name: str - label: Optional[dict] = None - form_schema: Optional[list] = None + label: dict | None = None + form_schema: list | None = None builtin: bool = True - position: Optional[int] = None + position: int | None = None class Extensible: @@ -32,9 +32,9 @@ class Extensible: name: str tenant_id: str - config: Optional[dict] = None + config: dict | None = None - def __init__(self, tenant_id: str, config: Optional[dict] = None): + def __init__(self, tenant_id: str, config: dict | None = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 45878e763f..564801f189 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor @@ -39,7 +37,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: dict, query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 81f1aaf174..cbec2e4e42 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from core.extension.extensible import Extensible, ExtensionModule @@ -16,7 +15,7 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): + def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @@ -34,7 +33,7 @@ class ExternalDataTool(Extensible, ABC): raise NotImplementedError @abstractmethod - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: dict, query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 6a9703a569..86bbb7060c 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from typing import Any, Optional +from typing import Any from flask import Flask, current_app @@ -63,7 +63,7 @@ class ExternalDataFetch: external_data_tool: ExternalDataVariableEntity, inputs: Mapping[str, Any], query: str, - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Query external data tool. :param flask_app: flask app diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 538bc3f525..6c542d681b 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -26,7 +26,7 @@ class ExternalDataToolFactory: # FIXME mypy issue here, figure out how to fix it extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ed1779fd13..0665ed7e0d 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -9,7 +9,3 @@ FILE_MODEL_IDENTITY = "__dify__file__" def maybe_file_object(o: Any) -> bool: return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY - - -# The default user ID for service API calls. -DEFAULT_SERVICE_API_USER_ID = "DEFAULT-USER" diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 3ec29fe23d..bf06dbd1ec 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -5,7 +5,6 @@ import os import time from configs import dify_config -from core.file.constants import DEFAULT_SERVICE_API_USER_ID def get_signed_file_url(upload_file_id: str) -> str: @@ -25,10 +24,6 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, # Plugin access should use internal URL for Docker network communication base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL url = f"{base_url}/files/upload/for-plugin" - - if user_id is None: - user_id = DEFAULT_SERVICE_API_USER_ID - timestamp = str(int(time.time())) nonce = os.urandom(16).hex() key = dify_config.SECRET_KEY.encode() @@ -40,11 +35,8 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str ) -> bool: - if user_id is None: - user_id = DEFAULT_SERVICE_API_USER_ID - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() diff --git a/api/core/file/models.py b/api/core/file/models.py index 9b74fa387f..dbef7564d6 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -26,7 +26,7 @@ class FileUploadConfig(BaseModel): File Upload Entity. """ - image_config: Optional[ImageConfig] = None + image_config: ImageConfig | None = None allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_extensions: Sequence[str] = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) @@ -38,21 +38,21 @@ class File(BaseModel): # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY - id: Optional[str] = None # message file id + id: str | None = None # message file id tenant_id: str type: FileType transfer_method: FileTransferMethod # If `transfer_method` is `FileTransferMethod.remote_url`, the # `remote_url` attribute must not be `None`. - remote_url: Optional[str] = None # remote url + remote_url: str | None = None # remote url # If `transfer_method` is `FileTransferMethod.local_file` or # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. # # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: Optional[str] = None - filename: Optional[str] = None - extension: Optional[str] = Field(default=None, description="File extension, should contain dot") - mime_type: Optional[str] = None + related_id: str | None = None + filename: str | None = None + extension: str | None = Field(default=None, description="File extension, should contain dot") + mime_type: str | None = None size: int = -1 # Those properties are private, should not be exposed to the outside. @@ -61,19 +61,19 @@ class File(BaseModel): def __init__( self, *, - id: Optional[str] = None, + id: str | None = None, tenant_id: str, type: FileType, transfer_method: FileTransferMethod, - remote_url: Optional[str] = None, - related_id: Optional[str] = None, - filename: Optional[str] = None, - extension: Optional[str] = None, - mime_type: Optional[str] = None, + remote_url: str | None = None, + related_id: str | None = None, + filename: str | None = None, + extension: str | None = None, + mime_type: str | None = None, size: int = -1, - storage_key: Optional[str] = None, - dify_model_identity: Optional[str] = FILE_MODEL_IDENTITY, - url: Optional[str] = None, + storage_key: str | None = None, + dify_model_identity: str | None = FILE_MODEL_IDENTITY, + url: str | None = None, ): super().__init__( id=id, @@ -108,7 +108,7 @@ class File(BaseModel): return text - def generate_url(self) -> Optional[str]: + def generate_url(self) -> str | None: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.remote_url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 2b580cb373..c44a8e1840 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -2,7 +2,7 @@ import logging from collections.abc import Mapping from enum import StrEnum from threading import Lock -from typing import Any, Optional +from typing import Any from httpx import Timeout, post from pydantic import BaseModel @@ -24,8 +24,8 @@ class CodeExecutionError(Exception): class CodeExecutionResponse(BaseModel): class Data(BaseModel): - stdout: Optional[str] = None - error: Optional[str] = None + stdout: str | None = None + error: str | None = None code: int message: str diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 1c112007cb..00fcfe0b80 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,12 +1,11 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError -from typing import Optional from extensions.ext_redis import redis_client -class ProviderCredentialsCacheType(Enum): +class ProviderCredentialsCacheType(StrEnum): PROVIDER = "provider" MODEL = "provider_model" LOAD_BALANCING_MODEL = "load_balancing_provider_model" @@ -14,9 +13,9 @@ class ProviderCredentialsCacheType(Enum): class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """ Get cached model provider credentials. diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 314f052832..2fc8fbf885 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -1,12 +1,14 @@ import os from collections import OrderedDict from collections.abc import Callable +from functools import lru_cache from typing import TypeVar from configs import dify_config -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import load_yaml_file_cached +@lru_cache(maxsize=128) def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping from name to index from a YAML file @@ -14,12 +16,17 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> :param file_name: the YAML file name, default to '_position.yaml' :return: a dict with name as key and index as value """ + # FIXME(-LAN-): Cache position maps to prevent file descriptor exhaustion during high-load benchmarks position_file_path = os.path.join(folder_path, file_name) - yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) + try: + yaml_content = load_yaml_file_cached(file_path=position_file_path) + except Exception: + yaml_content = [] positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] return {name: index for index, name in enumerate(positions)} +@lru_cache(maxsize=128) def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: """ Get the mapping for tools from name to index from a YAML file. @@ -35,20 +42,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") - ) -def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: - """ - Get the mapping for providers from name to index from a YAML file. - :param folder_path: - :param file_name: the YAML file name, default to '_position.yaml' - :return: a dict with name as key and index as value - """ - position_map = get_position_map(folder_path, file_name=file_name) - return pin_position_map( - position_map, - pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, - ) - - def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: """ Pin the items in the pin list to the beginning of the position map. diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 26e738fced..ffb5148386 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any, Optional +from typing import Any from extensions.ext_redis import redis_client @@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC): """Generate cache key based on subclass implementation""" pass - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider credentials""" cached_credentials = redis_client.get(self.cache_key) if cached_credentials: @@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): class NoOpProviderCredentialCache: """No-op provider credential cache""" - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider credentials""" return None diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 95a1086ca8..54674d4ff6 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,12 +1,11 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError -from typing import Optional from extensions.ext_redis import redis_client -class ToolParameterCacheType(Enum): +class ToolParameterCacheType(StrEnum): PARAMETER = "tool_parameter" @@ -15,11 +14,11 @@ class ToolParameterCache: self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str ): self.cache_key = ( - f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" f":identity_id:{identity_id}" ) - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """ Get cached model provider credentials. diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 35e6e292d1..820502e558 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -1,7 +1,7 @@ import contextlib import re from collections.abc import Mapping -from typing import Any, Optional +from typing import Any def is_valid_trace_id(trace_id: str) -> bool: @@ -13,7 +13,7 @@ def is_valid_trace_id(trace_id: str) -> bool: return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id)) -def get_external_trace_id(request: Any) -> Optional[str]: +def get_external_trace_id(request: Any) -> str | None: """ Retrieve the trace_id from the request. @@ -61,7 +61,7 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]): return {} -def get_trace_id_from_otel_context() -> Optional[str]: +def get_trace_id_from_otel_context() -> str | None: """ Retrieve the current trace ID from the active OpenTelemetry trace context. Returns None if: @@ -88,7 +88,7 @@ def get_trace_id_from_otel_context() -> Optional[str]: return None -def parse_traceparent_header(traceparent: str) -> Optional[str]: +def parse_traceparent_header(traceparent: str) -> str | None: """ Parse the `traceparent` header to extract the trace_id. diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index a5d7f7aac7..af860a1070 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,5 +1,3 @@ -from typing import Optional - from flask import Flask from pydantic import BaseModel @@ -30,8 +28,8 @@ class FreeHostingQuota(HostingQuota): class HostingProvider(BaseModel): enabled: bool = False - credentials: Optional[dict] = None - quota_unit: Optional[QuotaUnit] = None + credentials: dict | None = None + quota_unit: QuotaUnit | None = None quotas: list[HostingQuota] = [] @@ -42,7 +40,7 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: provider_map: dict[str, HostingProvider] - moderation_config: Optional[HostedModerationConfig] = None + moderation_config: HostedModerationConfig | None = None def __init__(self): self.provider_map = {} diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index ed02b70b03..94e88b55b9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,7 +5,7 @@ import re import threading import time import uuid -from typing import Any, Optional +from typing import Any from flask import current_app from sqlalchemy import select @@ -230,9 +230,9 @@ class IndexingRunner: tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: Optional[str] = None, + doc_form: str | None = None, doc_language: str = "English", - dataset_id: Optional[str] = None, + dataset_id: str | None = None, indexing_technique: str = "economy", ) -> IndexingEstimate: """ @@ -421,7 +421,7 @@ class IndexingRunner: max_tokens: int, chunk_overlap: int, separator: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. @@ -655,7 +655,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + document_id: str, after_indexing_status: str, extra_update_params: dict | None = None ): """ Update the document indexing status. diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index d4c4f10a12..83c727ffe0 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,7 +2,7 @@ import json import logging import re from collections.abc import Sequence -from typing import Optional, cast +from typing import cast import json_repair @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class LLMGenerator: @classmethod def generate_conversation_name( - cls, tenant_id: str, query, conversation_id: Optional[str] = None, app_id: Optional[str] = None + cls, tenant_id: str, query, conversation_id: str | None = None, app_id: str | None = None ): prompt = CONVERSATION_TITLE_PROMPT diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index e0b70c132f..1e302b7668 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from enum import StrEnum -from typing import Any, Literal, Optional, cast, overload +from typing import Any, Literal, cast, overload import json_repair from pydantic import TypeAdapter, ValidationError @@ -51,12 +51,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[True], - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... @overload def invoke_llm_with_structured_output( @@ -66,12 +66,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[False], - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload def invoke_llm_with_structured_output( @@ -81,12 +81,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... def invoke_llm_with_structured_output( *, @@ -95,12 +95,12 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Optional[Mapping] = None, + model_parameters: Mapping | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ Invoke large language model with structured output @@ -166,7 +166,7 @@ def invoke_llm_with_structured_output( def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]: result_text: str = "" prompt_messages: Sequence[PromptMessage] = [] - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None for event in llm_result: if isinstance(event, LLMResultChunk): prompt_messages = event.prompt_messages diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index a56c6ef86e..a6bb4fdda9 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -4,7 +4,7 @@ import json import os import secrets import urllib.parse -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from urllib.parse import urljoin, urlparse import httpx @@ -127,7 +127,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: return False, "" -def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: +def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" # First check if the server supports OAuth 2.0 Resource Discovery support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) @@ -157,7 +157,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N def start_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, redirect_url: str, provider_id: str, @@ -212,7 +212,7 @@ def start_authorization( def exchange_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, authorization_code: str, code_verifier: str, @@ -247,7 +247,7 @@ def exchange_authorization( def refresh_authorization( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_information: OAuthClientInformation, refresh_token: str, ) -> OAuthTokens: @@ -278,7 +278,7 @@ def refresh_authorization( def register_client( server_url: str, - metadata: Optional[OAuthMetadata], + metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, ) -> OAuthClientInformationFull: """Performs OAuth 2.0 Dynamic Client Registration.""" @@ -302,8 +302,8 @@ def register_client( def auth( provider: MCPProviderEntity, mcp_service: "MCPToolManageService", - authorization_code: Optional[str] = None, - state_param: Optional[str] = None, + authorization_code: str | None = None, + state_param: str | None = None, ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" server_url = provider.decrypt_server_url() diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 6f52c65234..212c2eb073 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -142,7 +142,7 @@ def handle_call_tool( end_user, args, InvokeFrom.SERVICE_API, - streaming=app.mode == AppMode.AGENT_CHAT.value, + streaming=app.mode == AppMode.AGENT_CHAT, ) answer = extract_answer_from_response(app, response) @@ -157,7 +157,7 @@ def build_parameter_schema( """Build parameter schema for the tool""" parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) - if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}: return { "type": "object", "properties": parameters, @@ -175,9 +175,9 @@ def build_parameter_schema( def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW.value: + if app.mode == AppMode.WORKFLOW: return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION.value: + elif app.mode == AppMode.COMPLETION: return {"query": "", "inputs": arguments} else: # Chat modes - create a copy to avoid modifying original dict @@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" if app.mode in { - AppMode.ADVANCED_CHAT.value, - AppMode.COMPLETION.value, - AppMode.CHAT.value, - AppMode.AGENT_CHAT.value, + AppMode.ADVANCED_CHAT, + AppMode.COMPLETION, + AppMode.CHAT, + AppMode.AGENT_CHAT, }: return response.get("answer", "") - elif app.mode == AppMode.WORKFLOW.value: + elif app.mode == AppMode.WORKFLOW: return json.dumps(response["data"]["outputs"], ensure_ascii=False) else: raise ValueError("Invalid app mode: " + str(app.mode)) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index fbad5576aa..653b3773c0 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -4,7 +4,7 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Optional, Self, TypeVar +from typing import Any, Generic, Self, TypeVar from httpx import HTTPStatusError from pydantic import BaseModel @@ -212,7 +212,7 @@ class BaseSession( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, - metadata: Optional[MessageMetadata] = None, + metadata: MessageMetadata | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 55c989ca1e..10c1ebb7a9 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1,6 +1,6 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Annotated, Any, Generic, Literal, Optional, TypeAlias, TypeVar +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints @@ -1288,45 +1288,45 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: Optional[MessageMetadata] = None + metadata: MessageMetadata | None = None class OAuthClientMetadata(BaseModel): client_name: str redirect_uris: list[str] - grant_types: Optional[list[str]] = None - response_types: Optional[list[str]] = None - token_endpoint_auth_method: Optional[str] = None - client_uri: Optional[str] = None - scope: Optional[str] = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + token_endpoint_auth_method: str | None = None + client_uri: str | None = None + scope: str | None = None class OAuthClientInformation(BaseModel): client_id: str - client_secret: Optional[str] = None + client_secret: str | None = None class OAuthClientInformationFull(OAuthClientInformation): client_name: str | None = None redirect_uris: list[str] - scope: Optional[str] = None - grant_types: Optional[list[str]] = None - response_types: Optional[list[str]] = None - token_endpoint_auth_method: Optional[str] = None + scope: str | None = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + token_endpoint_auth_method: str | None = None class OAuthTokens(BaseModel): access_token: str token_type: str - expires_in: Optional[int] = None - refresh_token: Optional[str] = None - scope: Optional[str] = None + expires_in: int | None = None + refresh_token: str | None = None + scope: str | None = None class OAuthMetadata(BaseModel): authorization_endpoint: str token_endpoint: str - registration_endpoint: Optional[str] = None + registration_endpoint: str | None = None response_types_supported: list[str] - grant_types_supported: Optional[list[str]] = None - code_challenge_methods_supported: Optional[list[str]] = None + grant_types_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1f2525cfed..35af742f2a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Optional from sqlalchemy import select @@ -96,7 +95,7 @@ class TokenBufferMemory: return AssistantPromptMessage(content=prompt_message_contents) def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: int | None = None ) -> Sequence[PromptMessage]: """ Get history prompt messages. @@ -187,7 +186,7 @@ class TokenBufferMemory: human_prefix: str = "Human", ai_prefix: str = "Assistant", max_token_limit: int = 2000, - message_limit: Optional[int] = None, + message_limit: int | None = None, ) -> str: """ Get history prompt text. diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 10df2ad79e..a63e94d59c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -103,47 +103,47 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[True] = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator: ... @overload def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: Literal[False] = False, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> LLMResult: ... @overload def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: Optional[dict] = None, + model_parameters: dict | None = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -176,7 +176,7 @@ class ModelInstance: ) def get_llm_num_tokens( - self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None + self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None ) -> int: """ Get number of tokens for llm @@ -199,7 +199,7 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> TextEmbeddingResult: """ Invoke large language model @@ -246,9 +246,9 @@ class ModelInstance: self, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -276,7 +276,7 @@ class ModelInstance: ), ) - def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -297,7 +297,7 @@ class ModelInstance: ), ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: + def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: """ Invoke large language model @@ -318,7 +318,7 @@ class ModelInstance: ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: """ Invoke large language tts model @@ -397,7 +397,7 @@ class ModelInstance: except Exception as e: raise e - def get_tts_voices(self, language: Optional[str] = None): + def get_tts_voices(self, language: str | None = None): """ Invoke large language tts model voices @@ -470,7 +470,7 @@ class LBModelManager: model_type: ModelType, model: str, load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None, + managed_credentials: dict | None = None, ): """ Load balancing model manager @@ -495,7 +495,7 @@ class LBModelManager: else: load_balancing_config.credentials = managed_credentials - def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: + def fetch_next(self) -> ModelLoadBalancingConfiguration | None: """ Get next model load balancing config Strategy: Round Robin diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 5ce4c23dbb..a745a91510 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -31,10 +30,10 @@ class Callback(ABC): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Before invoke callback @@ -60,10 +59,10 @@ class Callback(ABC): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ On new chunk callback @@ -90,10 +89,10 @@ class Callback(ABC): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ After invoke callback @@ -120,10 +119,10 @@ class Callback(ABC): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Invoke error callback @@ -141,7 +140,7 @@ class Callback(ABC): """ raise NotImplementedError() - def print_text(self, text: str, color: Optional[str] = None, end: str = ""): + def print_text(self, text: str, color: str | None = None, end: str = ""): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 8411afca92..b366fcc57b 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -2,7 +2,7 @@ import json import logging import sys from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,10 +20,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Before invoke callback @@ -76,10 +76,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ On new chunk callback @@ -106,10 +106,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ After invoke callback @@ -147,10 +147,10 @@ class LoggingCallback(Callback): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, + user: str | None = None, ): """ Invoke error callback diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 659ad59bd6..c7353de5af 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -8,7 +6,7 @@ class I18nObject(BaseModel): Model class for i18n object. """ - zh_Hans: Optional[str] = None + zh_Hans: str | None = None en_US: str def __init__(self, **data): diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index d5caddb7a3..17f6000d93 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Any, Optional, TypedDict, Union +from typing import Any, TypedDict, Union from pydantic import BaseModel, Field @@ -150,13 +150,13 @@ class LLMResult(BaseModel): Model class for llm result. """ - id: Optional[str] = None + id: str | None = None model: str prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) message: AssistantPromptMessage usage: LLMUsage - system_fingerprint: Optional[str] = None - reasoning_content: Optional[str] = None + system_fingerprint: str | None = None + reasoning_content: str | None = None class LLMStructuredOutput(BaseModel): @@ -164,7 +164,7 @@ class LLMStructuredOutput(BaseModel): Model class for llm structured output. """ - structured_output: Optional[Mapping[str, Any]] = None + structured_output: Mapping[str, Any] | None = None class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): @@ -180,8 +180,8 @@ class LLMResultChunkDelta(BaseModel): index: int message: AssistantPromptMessage - usage: Optional[LLMUsage] = None - finish_reason: Optional[str] = None + usage: LLMUsage | None = None + finish_reason: str | None = None class LLMResultChunk(BaseModel): @@ -191,7 +191,7 @@ class LLMResultChunk(BaseModel): model: str prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None delta: LLMResultChunkDelta diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 7cd2e6a3d1..9235c881e0 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,20 +1,20 @@ from abc import ABC from collections.abc import Mapping, Sequence -from enum import Enum, StrEnum -from typing import Annotated, Any, Literal, Optional, Union +from enum import StrEnum, auto +from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, Field, field_serializer, field_validator -class PromptMessageRole(Enum): +class PromptMessageRole(StrEnum): """ Enum class for prompt message. """ - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" + SYSTEM = auto() + USER = auto() + ASSISTANT = auto() + TOOL = auto() @classmethod def value_of(cls, value: str) -> "PromptMessageRole": @@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum): Enum class for prompt message content type. """ - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - DOCUMENT = "document" + TEXT = auto() + IMAGE = auto() + AUDIO = auto() + VIDEO = auto() + DOCUMENT = auto() class PromptMessageContent(ABC, BaseModel): @@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent): """ class DETAIL(StrEnum): - LOW = "low" - HIGH = "high" + LOW = auto() + HIGH = auto() type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -146,8 +146,8 @@ class PromptMessage(ABC, BaseModel): """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContentUnionTypes]] = None - name: Optional[str] = None + content: str | list[PromptMessageContentUnionTypes] | None = None + name: str | None = None def is_empty(self) -> bool: """ @@ -193,8 +193,8 @@ class PromptMessage(ABC, BaseModel): @field_serializer("content") def serialize_content( - self, content: Optional[Union[str, Sequence[PromptMessageContent]]] - ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]: + self, content: Union[str, Sequence[PromptMessageContent]] | None + ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: if content is None or isinstance(content, str): return content if isinstance(content, list): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 568149cc37..aee6ce1108 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,23 +1,23 @@ from decimal import Decimal -from enum import Enum, StrEnum -from typing import Any, Optional +from enum import StrEnum, auto +from typing import Any from pydantic import BaseModel, ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject -class ModelType(Enum): +class ModelType(StrEnum): """ Enum class for model type. """ - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - TTS = "tts" + RERANK = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + TTS = auto() @classmethod def value_of(cls, origin_model_type: str) -> "ModelType": @@ -26,17 +26,17 @@ class ModelType(Enum): :return: model type """ - if origin_model_type in {"text-generation", cls.LLM.value}: + if origin_model_type in {"text-generation", cls.LLM}: return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK.value}: + elif origin_model_type in {"reranking", cls.RERANK}: return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS.value}: + elif origin_model_type in {"tts", cls.TTS}: return cls.TTS - elif origin_model_type == cls.MODERATION.value: + elif origin_model_type == cls.MODERATION: return cls.MODERATION else: raise ValueError(f"invalid origin model type {origin_model_type}") @@ -63,7 +63,7 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") -class FetchFrom(Enum): +class FetchFrom(StrEnum): """ Enum class for fetch from. """ @@ -72,7 +72,7 @@ class FetchFrom(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class ModelFeature(Enum): +class ModelFeature(StrEnum): """ Enum class for llm feature. """ @@ -80,11 +80,11 @@ class ModelFeature(Enum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() STRUCTURED_OUTPUT = "structured-output" @@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum): Enum class for parameter template variable. """ - TEMPERATURE = "temperature" - TOP_P = "top_p" - TOP_K = "top_k" - PRESENCE_PENALTY = "presence_penalty" - FREQUENCY_PENALTY = "frequency_penalty" - MAX_TOKENS = "max_tokens" - RESPONSE_FORMAT = "response_format" - JSON_SCHEMA = "json_schema" + TEMPERATURE = auto() + TOP_P = auto() + TOP_K = auto() + PRESENCE_PENALTY = auto() + FREQUENCY_PENALTY = auto() + MAX_TOKENS = auto() + RESPONSE_FORMAT = auto() + JSON_SCHEMA = auto() @classmethod def value_of(cls, value: Any) -> "DefaultParameterName": @@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum): raise ValueError(f"invalid parameter name {value}") -class ParameterType(Enum): +class ParameterType(StrEnum): """ Enum class for parameter type. """ - FLOAT = "float" - INT = "int" - STRING = "string" - BOOLEAN = "boolean" - TEXT = "text" + FLOAT = auto() + INT = auto() + STRING = auto() + BOOLEAN = auto() + TEXT = auto() -class ModelPropertyKey(Enum): +class ModelPropertyKey(StrEnum): """ Enum class for model property key. """ - MODE = "mode" - CONTEXT_SIZE = "context_size" - MAX_CHUNKS = "max_chunks" - FILE_UPLOAD_LIMIT = "file_upload_limit" - SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions" - MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk" - DEFAULT_VOICE = "default_voice" - VOICES = "voices" - WORD_LIMIT = "word_limit" - AUDIO_TYPE = "audio_type" - MAX_WORKERS = "max_workers" + MODE = auto() + CONTEXT_SIZE = auto() + MAX_CHUNKS = auto() + FILE_UPLOAD_LIMIT = auto() + SUPPORTED_FILE_EXTENSIONS = auto() + MAX_CHARACTERS_PER_CHUNK = auto() + DEFAULT_VOICE = auto() + VOICES = auto() + WORD_LIMIT = auto() + AUDIO_TYPE = auto() + MAX_WORKERS = auto() class ProviderModel(BaseModel): @@ -154,7 +154,7 @@ class ProviderModel(BaseModel): model: str label: I18nObject model_type: ModelType - features: Optional[list[ModelFeature]] = None + features: list[ModelFeature] | None = None fetch_from: FetchFrom model_properties: dict[ModelPropertyKey, Any] deprecated: bool = False @@ -171,15 +171,15 @@ class ParameterRule(BaseModel): """ name: str - use_template: Optional[str] = None + use_template: str | None = None label: I18nObject type: ParameterType - help: Optional[I18nObject] = None + help: I18nObject | None = None required: bool = False - default: Optional[Any] = None - min: Optional[float] = None - max: Optional[float] = None - precision: Optional[int] = None + default: Any | None = None + min: float | None = None + max: float | None = None + precision: int | None = None options: list[str] = [] @@ -189,7 +189,7 @@ class PriceConfig(BaseModel): """ input: Decimal - output: Optional[Decimal] = None + output: Decimal | None = None unit: Decimal currency: str @@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel): """ parameter_rules: list[ParameterRule] = [] - pricing: Optional[PriceConfig] = None + pricing: PriceConfig | None = None @model_validator(mode="after") def validate_model(self): @@ -220,13 +220,13 @@ class ModelUsage(BaseModel): pass -class PriceType(Enum): +class PriceType(StrEnum): """ Enum class for price type. """ - INPUT = "input" - OUTPUT = "output" + INPUT = auto() + OUTPUT = auto() class PriceInfo(BaseModel): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index c9aa8d1474..2ccc9e0eae 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from enum import Enum -from typing import Optional +from enum import Enum, StrEnum, auto from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -17,16 +16,16 @@ class ConfigurateMethod(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class FormType(Enum): +class FormType(StrEnum): """ Enum class for form type. """ TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" - SELECT = "select" - RADIO = "radio" - SWITCH = "switch" + SELECT = auto() + RADIO = auto() + SWITCH = auto() class FormShowOnObject(BaseModel): @@ -62,9 +61,9 @@ class CredentialFormSchema(BaseModel): label: I18nObject type: FormType required: bool = True - default: Optional[str] = None - options: Optional[list[FormOption]] = None - placeholder: Optional[I18nObject] = None + default: str | None = None + options: list[FormOption] | None = None + placeholder: I18nObject | None = None max_length: int = 0 show_on: list[FormShowOnObject] = [] @@ -79,7 +78,7 @@ class ProviderCredentialSchema(BaseModel): class FieldModelSchema(BaseModel): label: I18nObject - placeholder: Optional[I18nObject] = None + placeholder: I18nObject | None = None class ModelCredentialSchema(BaseModel): @@ -98,8 +97,8 @@ class SimpleProviderEntity(BaseModel): provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -120,24 +119,24 @@ class ProviderEntity(BaseModel): provider: str label: I18nObject - description: Optional[I18nObject] = None - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None - icon_small_dark: Optional[I18nObject] = None - icon_large_dark: Optional[I18nObject] = None - background: Optional[str] = None - help: Optional[ProviderHelpEntity] = None + description: I18nObject | None = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None + icon_small_dark: I18nObject | None = None + icon_large_dark: I18nObject | None = None + background: str | None = None + help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] configurate_methods: list[ConfigurateMethod] models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: Optional[ProviderCredentialSchema] = None - model_credential_schema: Optional[ModelCredentialSchema] = None + provider_credential_schema: ProviderCredentialSchema | None = None + model_credential_schema: ModelCredentialSchema | None = None # pydantic configs model_config = ConfigDict(protected_namespaces=()) # position from plugin _position.yaml - position: Optional[dict[str, list[str]]] = {} + position: dict[str, list[str]] | None = {} @field_validator("models", mode="before") @classmethod diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 6bcb707684..80cf01fb6c 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,12 +1,9 @@ -from typing import Optional - - class InvokeError(ValueError): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description def __str__(self): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index f41818e270..a3d743c373 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import hashlib from threading import Lock -from typing import Optional from pydantic import BaseModel, ConfigDict, Field @@ -99,7 +98,7 @@ class AIModel(BaseModel): model_schema = self.get_model_schema(model, credentials) # get price info from predefined model schema - price_config: Optional[PriceConfig] = None + price_config: PriceConfig | None = None if model_schema and model_schema.pricing: price_config = model_schema.pricing @@ -132,7 +131,7 @@ class AIModel(BaseModel): currency=price_config.currency, ) - def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: """ Get model schema by model name and credentials @@ -171,7 +170,7 @@ class AIModel(BaseModel): return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials @@ -229,7 +228,7 @@ class AIModel(BaseModel): return schema - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 1d7fd7d447..80dabffa10 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Sequence -from typing import Optional, Union +from typing import Union from pydantic import ConfigDict @@ -94,12 +94,12 @@ class LargeLanguageModel(AIModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ Invoke large language model @@ -243,11 +243,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunk, None, None]: """ Invoke result generator @@ -328,7 +328,7 @@ class LargeLanguageModel(AIModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> int: """ Get number of tokens for given prompt messages @@ -403,11 +403,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger before invoke callbacks @@ -451,11 +451,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger new chunk callbacks @@ -498,11 +498,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger after invoke callbacks @@ -548,11 +548,11 @@ class LargeLanguageModel(AIModel): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, stream: bool = True, - user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None, + user: str | None = None, + callbacks: list[Callback] | None = None, ): """ Trigger invoke error callbacks diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 19dc1d599a..c3ce6f17ad 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -1,5 +1,4 @@ import time -from typing import Optional from pydantic import ConfigDict @@ -18,7 +17,7 @@ class ModerationModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: """ Invoke moderation model diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 569e756a3b..81a434405f 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -19,9 +17,9 @@ class RerankModel(AIModel): credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index c69f65b681..57d7ccf350 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -1,4 +1,4 @@ -from typing import IO, Optional +from typing import IO from pydantic import ConfigDict @@ -17,7 +17,7 @@ class Speech2TextModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: """ Invoke speech to text model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index f7bba0eba1..8b335c4951 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType @@ -24,7 +22,7 @@ class TextEmbeddingModel(AIModel): model: str, credentials: dict, texts: list[str], - user: Optional[str] = None, + user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: """ @@ -47,7 +45,7 @@ class TextEmbeddingModel(AIModel): model=model, credentials=credentials, texts=texts, - input_type=input_type.value, + input_type=input_type, ) except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index 8f8a638af6..23d36c03af 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -1,10 +1,10 @@ import logging from threading import Lock -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) -_tokenizer: Optional[Any] = None +_tokenizer: Any | None = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 9ee29f2f2f..ca391162a0 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,6 +1,5 @@ import logging from collections.abc import Iterable -from typing import Optional from pydantic import ConfigDict @@ -28,7 +27,7 @@ class TTSModel(AIModel): credentials: dict, content_text: str, voice: str, - user: Optional[str] = None, + user: str | None = None, ) -> Iterable[bytes]: """ Invoke large language model @@ -56,7 +55,7 @@ class TTSModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) - def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None): + def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): """ Retrieves the list of voices supported by a given text-to-speech (TTS) model. diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 6502b920f5..2434425933 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,14 +1,9 @@ import hashlib import logging -import os from collections.abc import Sequence from threading import Lock -from typing import Optional - -from pydantic import BaseModel import contexts -from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -28,48 +23,20 @@ from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) -class ModelProviderExtension(BaseModel): - plugin_model_provider_entity: PluginModelProviderEntity - position: Optional[int] = None - - class ModelProviderFactory: - provider_position_map: dict[str, int] - def __init__(self, tenant_id: str): - self.provider_position_map = {} - self.tenant_id = tenant_id self.plugin_model_manager = PluginModelClient() - if not self.provider_position_map: - # get the path of current classes - current_path = os.path.abspath(__file__) - model_providers_path = os.path.dirname(current_path) - - # get _position.yaml file path - self.provider_position_map = get_provider_position_map(model_providers_path) - def get_providers(self) -> Sequence[ProviderEntity]: """ Get all providers :return: list of providers """ - # Fetch plugin model providers + # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server + # The plugin server should return providers in the desired order plugin_providers = self.get_plugin_model_providers() - - # Convert PluginModelProviderEntity to ModelProviderExtension - model_provider_extensions = [] - for provider in plugin_providers: - model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider)) - - sorted_extensions = sort_to_dict_by_position_map( - position_map=self.provider_position_map, - data=model_provider_extensions, - name_func=lambda x: x.plugin_model_provider_entity.declaration.provider, - ) - - return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()] + return [provider.declaration for provider in plugin_providers] def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: """ @@ -238,9 +205,9 @@ class ModelProviderFactory: def get_models( self, *, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - provider_configs: Optional[list[ProviderConfig]] = None, + provider: str | None = None, + model_type: ModelType | None = None, + provider_configs: list[ProviderConfig] | None = None, ) -> list[SimpleProviderEntity]: """ Get all models for given model type diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 962e417671..c758eaf49f 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -8,7 +8,7 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6 from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from uuid import UUID from pydantic import BaseModel @@ -18,7 +18,7 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) @@ -98,9 +98,9 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, + custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, sqlalchemy_safe: bool = True, -): +) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: if type(obj) in custom_encoder: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index ce7bd21110..573f4ec2a7 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from sqlalchemy import select @@ -87,7 +85,7 @@ class ApiModeration(Moderation): return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension | None: stmt = select(APIBasedExtension).where( APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 752617b654..d76b4689be 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,15 +1,14 @@ from abc import ABC, abstractmethod -from enum import Enum -from typing import Optional +from enum import StrEnum, auto from pydantic import BaseModel, Field from core.extension.extensible import Extensible, ExtensionModule -class ModerationAction(Enum): - DIRECT_OUTPUT = "direct_output" - OVERRIDDEN = "overridden" +class ModerationAction(StrEnum): + DIRECT_OUTPUT = auto() + OVERRIDDEN = auto() class ModerationInputsResult(BaseModel): @@ -34,7 +33,7 @@ class Moderation(Extensible, ABC): module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None): + def __init__(self, app_id: str, tenant_id: str, config: dict | None = None): super().__init__(tenant_id, config) self.app_id = app_id diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 3ac33966cb..21dc58f16f 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -21,7 +21,7 @@ class InputModeration: inputs: Mapping[str, Any], query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None, + trace_manager: TraceQueueManager | None = None, ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 6993ec8b0b..a97e3d4253 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import Any, Optional +from typing import Any from flask import Flask, current_app from pydantic import BaseModel, ConfigDict @@ -27,11 +27,11 @@ class OutputModeration(BaseModel): rule: ModerationRule queue_manager: AppQueueManager - thread: Optional[threading.Thread] = None + thread: threading.Thread | None = None thread_running: bool = True buffer: str = "" is_final_chunk: bool = False - final_output: Optional[str] = None + final_output: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def should_direct_output(self) -> bool: @@ -127,7 +127,7 @@ class OutputModeration(BaseModel): if result.action == ModerationAction.DIRECT_OUTPUT: break - def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: + def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> ModerationOutputsResult | None: try: moderation_factory = ModerationFactory( name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index c661050637..d9519bb078 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,7 +1,6 @@ import json import logging from collections.abc import Sequence -from typing import Optional from urllib.parse import urljoin from opentelemetry.trace import Link, Status, StatusCode @@ -123,7 +122,7 @@ class AliyunDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -356,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance): GEN_AI_FRAMEWORK: "dify", TOOL_NAME: node_execution.title, TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False), - TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), - INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), + TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False), + INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False), OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), }, status=self.get_workflow_node_status(node_execution), diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 881ec2141c..09cb6e3fc1 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,7 +7,6 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime -from typing import Optional import requests from opentelemetry import trace as trace_api @@ -184,7 +183,7 @@ def generate_span_id() -> int: return span_id -def convert_to_trace_id(uuid_v4: Optional[str]) -> int: +def convert_to_trace_id(uuid_v4: str | None) -> int: try: uuid_obj = uuid.UUID(uuid_v4) return uuid_obj.int @@ -192,7 +191,7 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int: raise ValueError(f"Invalid UUID input: {e}") -def convert_string_to_id(string: Optional[str]) -> int: +def convert_string_to_id(string: str | None) -> int: if not string: return generate_span_id() hash_bytes = hashlib.sha256(string.encode("utf-8")).digest() @@ -200,7 +199,7 @@ def convert_string_to_id(string: Optional[str]) -> int: return id -def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: +def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: try: uuid_obj = uuid.UUID(uuid_v4) except Exception as e: @@ -209,7 +208,7 @@ def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: return convert_string_to_id(combined_key) -def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]: +def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None: if start_time_a is None: return None timestamp_in_seconds = start_time_a.timestamp() diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py index 1caa822cd0..f3dcbc5b8f 100644 --- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from typing import Optional from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event, Status, StatusCode @@ -10,12 +9,12 @@ class SpanData(BaseModel): model_config = {"arbitrary_types_allowed": True} trace_id: int = Field(..., description="The unique identifier for the trace.") - parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.") + parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.") span_id: int = Field(..., description="The unique identifier for this span.") name: str = Field(..., description="The name of the span.") attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.") events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.") links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.") status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") - start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.") - end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.") + start_time: int | None = Field(..., description="The start time of the span in nanoseconds.") + end_time: int | None = Field(..., description="The end time of the span in nanoseconds.") diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py index 5d70264320..c9427c776a 100644 --- a/api/core/ops/aliyun_trace/entities/semconv.py +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum # public GEN_AI_SESSION_ID = "gen_ai.session.id" @@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description" TOOL_PARAMETERS = "tool.parameters" -class GenAISpanKind(Enum): +class GenAISpanKind(StrEnum): CHAIN = "CHAIN" RETRIEVER = "RETRIEVER" RERANKER = "RERANKER" diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index c5fbe4d78b..1497bc1863 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from urllib.parse import urlparse from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes @@ -92,14 +92,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra raise -def datetime_to_nanos(dt: Optional[datetime]) -> int: +def datetime_to_nanos(dt: datetime | None) -> int: """Convert datetime to nanoseconds since epoch. If None, use current time.""" if dt is None: dt = datetime.now() return int(dt.timestamp() * 1_000_000_000) -def string_to_trace_id128(string: Optional[str]) -> int: +def string_to_trace_id128(string: str | None) -> int: """ Convert any input string into a stable 128-bit integer trace ID. @@ -284,7 +284,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): return file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -308,7 +308,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 1870da3781..d6f8164590 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,20 +1,20 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): - message_id: Optional[str] = None - message_data: Optional[Any] = None - inputs: Optional[Union[str, dict[str, Any], list]] = None - outputs: Optional[Union[str, dict[str, Any], list]] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None + message_id: str | None = None + message_data: Any | None = None + inputs: Union[str, dict[str, Any], list] | None = None + outputs: Union[str, dict[str, Any], list] | None = None + start_time: datetime | None = None + end_time: datetime | None = None metadata: dict[str, Any] - trace_id: Optional[str] = None + trace_id: str | None = None @field_validator("inputs", "outputs") @classmethod @@ -36,8 +36,8 @@ class BaseTraceInfo(BaseModel): class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any = None - conversation_id: Optional[str] = None - workflow_app_log_id: Optional[str] = None + conversation_id: str | None = None + workflow_app_log_id: str | None = None workflow_id: str tenant_id: str workflow_run_id: str @@ -46,7 +46,7 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_inputs: Mapping[str, Any] workflow_run_outputs: Mapping[str, Any] workflow_run_version: str - error: Optional[str] = None + error: str | None = None total_tokens: int file_list: list[str] query: str @@ -58,9 +58,9 @@ class MessageTraceInfo(BaseTraceInfo): message_tokens: int answer_tokens: int total_tokens: int - error: Optional[str] = None - file_list: Optional[Union[str, dict[str, Any], list]] = None - message_file_data: Optional[Any] = None + error: str | None = None + file_list: Union[str, dict[str, Any], list] | None = None + message_file_data: Any | None = None conversation_mode: str @@ -73,17 +73,17 @@ class ModerationTraceInfo(BaseTraceInfo): class SuggestedQuestionTraceInfo(BaseTraceInfo): total_tokens: int - status: Optional[str] = None - error: Optional[str] = None - from_account_id: Optional[str] = None - agent_based: Optional[bool] = None - from_source: Optional[str] = None - model_provider: Optional[str] = None - model_id: Optional[str] = None + status: str | None = None + error: str | None = None + from_account_id: str | None = None + agent_based: bool | None = None + from_source: str | None = None + model_provider: str | None = None + model_id: str | None = None suggested_question: list[str] level: str - status_message: Optional[str] = None - workflow_run_id: Optional[str] = None + status_message: str | None = None + workflow_run_id: str | None = None model_config = ConfigDict(protected_namespaces=()) @@ -98,7 +98,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_outputs: str metadata: dict[str, Any] message_file_data: Any = None - error: Optional[str] = None + error: str | None = None tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] @@ -106,7 +106,7 @@ class ToolTraceInfo(BaseTraceInfo): class GenerateNameTraceInfo(BaseTraceInfo): - conversation_id: Optional[str] = None + conversation_id: str | None = None tenant_id: str diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index 46ba1c45b9..312c7d3676 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -52,50 +52,50 @@ class LangfuseTrace(BaseModel): Langfuse trace model """ - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems " "or when creating a distributed trace. Traces are upserted on id.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the trace. Useful for sorting/filtering in the UI.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Union[str, dict[str, Any], list, None] | None = Field( default=None, description="The input of the trace. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Union[str, dict[str, Any], list, None] | None = Field( default=None, description="The output of the trace. Can be any JSON object." ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated " "via the API.", ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, description="The id of the user that triggered the execution. Used to provide user-level analytics.", ) - session_id: Optional[str] = Field( + session_id: str | None = Field( default=None, description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the trace type. Used to understand how changes to the trace type affect metrics. " "Useful in debugging.", ) - release: Optional[str] = Field( + release: str | None = Field( default=None, description="The release identifier of the current deployment. Used to understand how changes of different " "deployments affect metrics. Useful in debugging.", ) - tags: Optional[list[str]] = Field( + tags: list[str] | None = Field( default=None, description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET " "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.", ) - public: Optional[bool] = Field( + public: bool | None = Field( default=None, description="You can make a trace public to share it via a public link. This allows others to view the trace " "without needing to log in or be members of your Langfuse project.", @@ -113,61 +113,61 @@ class LangfuseSpan(BaseModel): Langfuse span model """ - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.", ) - session_id: Optional[str] = Field( + session_id: str | None = Field( default=None, description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.", ) - trace_id: Optional[str] = Field( + trace_id: str | None = Field( default=None, description="The id of the trace the span belongs to. Used to link spans to traces.", ) - user_id: Optional[str] = Field( + user_id: str | None = Field( default=None, description="The id of the user that triggered the execution. Used to provide user-level analytics.", ) - start_time: Optional[datetime | str] = Field( + start_time: datetime | str | None = Field( default_factory=datetime.now, description="The time at which the span started, defaults to the current time.", ) - end_time: Optional[datetime | str] = Field( + end_time: datetime | str | None = Field( default=None, description="The time at which the span ended. Automatically set by span.end().", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the span. Useful for sorting/filtering in the UI.", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated " "via the API.", ) - level: Optional[str] = Field( + level: str | None = Field( default=None, description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of " "traces with elevated error levels and for highlighting in the UI.", ) - status_message: Optional[str] = Field( + status_message: str | None = Field( default=None, description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( + input: Union[str, Mapping[str, Any], list, None] | None = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( + output: Union[str, Mapping[str, Any], list, None] | None = Field( default=None, description="The output of the span. Can be any JSON object." ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the span type. Used to understand how changes to the span type affect metrics. " "Useful in debugging.", ) - parent_observation_id: Optional[str] = Field( + parent_observation_id: str | None = Field( default=None, description="The id of the observation the span belongs to. Used to link spans to observations.", ) @@ -188,15 +188,15 @@ class UnitEnum(StrEnum): class GenerationUsage(BaseModel): - promptTokens: Optional[int] = None - completionTokens: Optional[int] = None - total: Optional[int] = None - input: Optional[int] = None - output: Optional[int] = None - unit: Optional[UnitEnum] = None - inputCost: Optional[float] = None - outputCost: Optional[float] = None - totalCost: Optional[float] = None + promptTokens: int | None = None + completionTokens: int | None = None + total: int | None = None + input: int | None = None + output: int | None = None + unit: UnitEnum | None = None + inputCost: float | None = None + outputCost: float | None = None + totalCost: float | None = None @field_validator("input", "output") @classmethod @@ -206,69 +206,69 @@ class GenerationUsage(BaseModel): class LangfuseGeneration(BaseModel): - id: Optional[str] = Field( + id: str | None = Field( default=None, description="The id of the generation can be set, defaults to random id.", ) - trace_id: Optional[str] = Field( + trace_id: str | None = Field( default=None, description="The id of the trace the generation belongs to. Used to link generations to traces.", ) - parent_observation_id: Optional[str] = Field( + parent_observation_id: str | None = Field( default=None, description="The id of the observation the generation belongs to. Used to link generations to observations.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Identifier of the generation. Useful for sorting/filtering in the UI.", ) - start_time: Optional[datetime | str] = Field( + start_time: datetime | str | None = Field( default_factory=datetime.now, description="The time at which the generation started, defaults to the current time.", ) - completion_start_time: Optional[datetime | str] = Field( + completion_start_time: datetime | str | None = Field( default=None, description="The time at which the completion started (streaming). Set it to get latency analytics broken " "down into time until completion started and completion duration.", ) - end_time: Optional[datetime | str] = Field( + end_time: datetime | str | None = Field( default=None, description="The time at which the generation ended. Automatically set by generation.end().", ) - model: Optional[str] = Field(default=None, description="The name of the model used for the generation.") - model_parameters: Optional[dict[str, Any]] = Field( + model: str | None = Field(default=None, description="The name of the model used for the generation.") + model_parameters: dict[str, Any] | None = Field( default=None, description="The parameters of the model used for the generation; can be any key-value pairs.", ) - input: Optional[Any] = Field( + input: Any | None = Field( default=None, description="The prompt used for the generation. Can be any string or JSON object.", ) - output: Optional[Any] = Field( + output: Any | None = Field( default=None, description="The completion generated by the model. Can be any string or JSON object.", ) - usage: Optional[GenerationUsage] = Field( + usage: GenerationUsage | None = Field( default=None, description="The usage object supports the OpenAi structure with tokens and a more generic version with " "detailed costs and units.", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: dict[str, Any] | None = Field( default=None, description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being " "updated via the API.", ) - level: Optional[LevelEnum] = Field( + level: LevelEnum | None = Field( default=None, description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering " "of traces with elevated error levels and for highlighting in the UI.", ) - status_message: Optional[str] = Field( + status_message: str | None = Field( default=None, description="The status message of the generation. Additional field for context of the event. E.g. the error " "message of an error event.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The version of the generation type. Used to understand how changes to the span type affect " "metrics. Useful in debugging.", diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 3a03d9f4fe..119dd52a5f 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,7 +1,6 @@ import logging import os from datetime import datetime, timedelta -from typing import Optional from langfuse import Langfuse # type: ignore from sqlalchemy.orm import sessionmaker @@ -145,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { @@ -164,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance): "status": status, } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} model_provider = process_data.get("model_provider", None) model_name = process_data.get("model_name", None) if model_provider is not None and model_name is not None: @@ -242,7 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -399,7 +398,7 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_span(langfuse_span_data=name_generation_span_data) - def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): + def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None): format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} try: self.langfuse_client.trace(**format_trace_data) @@ -407,7 +406,7 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create trace: {str(e)}") - def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None): + def add_span(self, langfuse_span_data: LangfuseSpan | None = None): format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} try: self.langfuse_client.span(**format_span_data) @@ -415,12 +414,12 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create span: {str(e)}") - def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None): + def update_span(self, span, langfuse_span_data: LangfuseSpan | None = None): format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} span.end(**format_span_data) - def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None): + def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None): format_generation_data = ( filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) @@ -430,7 +429,7 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"LangFuse Failed to create generation: {str(e)}") - def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None): + def update_generation(self, generation, langfuse_generation_data: LangfuseGeneration | None = None): format_generation_data = ( filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {} ) diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 4fd01136ba..f73ba01c8b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -20,36 +20,36 @@ class LangSmithRunType(StrEnum): class LangSmithTokenUsage(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - total_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None class LangSmithMultiModel(BaseModel): - file_list: Optional[list[str]] = Field(None, description="List of files") + file_list: list[str] | None = Field(None, description="List of files") class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): - name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") + name: str | None = Field(..., description="Name of the run") + inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the run") + outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") - start_time: Optional[datetime | str] = Field(None, description="Start time of the run") - end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") - error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") - parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") - tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") - id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field(None, description="Session ID associated with the run") - session_name: Optional[str] = Field(None, description="Session name associated with the run") - reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") - input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") - output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") + start_time: datetime | str | None = Field(None, description="Start time of the run") + end_time: datetime | str | None = Field(None, description="End time of the run") + extra: dict[str, Any] | None = Field(None, description="Extra information of the run") + error: str | None = Field(None, description="Error message of the run") + serialized: dict[str, Any] | None = Field(None, description="Serialized data of the run") + parent_run_id: str | None = Field(None, description="Parent run ID") + events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run") + tags: list[str] | None = Field(None, description="Tags associated with the run") + trace_id: str | None = Field(None, description="Trace ID associated with the run") + dotted_order: str | None = Field(None, description="Dotted order of the run") + id: str | None = Field(None, description="ID of the run") + session_id: str | None = Field(None, description="Session ID associated with the run") + session_name: str | None = Field(None, description="Session name associated with the run") + reference_example_id: str | None = Field(None, description="Reference example ID associated with the run") + input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run") + output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") @classmethod @@ -128,15 +128,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") - parent_run_id: Optional[str] = Field(None, description="Parent run ID") - end_time: Optional[datetime | str] = Field(None, description="End time of the run") - error: Optional[str] = Field(None, description="Error message of the run") - inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") - outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") - tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") - input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") - output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") + trace_id: str | None = Field(None, description="Trace ID associated with the run") + dotted_order: str | None = Field(None, description="Dotted order of the run") + parent_run_id: str | None = Field(None, description="Parent run ID") + end_time: datetime | str | None = Field(None, description="End time of the run") + error: str | None = Field(None, description="Error message of the run") + inputs: dict[str, Any] | None = Field(None, description="Inputs of the run") + outputs: dict[str, Any] | None = Field(None, description="Outputs of the run") + events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run") + tags: list[str] | None = Field(None, description="Tags associated with the run") + extra: dict[str, Any] | None = Field(None, description="Extra information of the run") + input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run") + output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f9e5128e89..6c24ac0e47 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import cast from langsmith import Client from langsmith.schemas import RunBase @@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( @@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm @@ -247,7 +247,7 @@ class LangSmithDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata @@ -260,7 +260,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index dd6a424ddb..98e9cb2dcb 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 @@ -47,7 +47,7 @@ def wrap_metadata(metadata, **kwargs): return metadata -def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]): +def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None): """Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most messages and objects. The type-hints of BaseTraceInfo indicates that objects start_time and message_id could be null which means we cannot map @@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { @@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} provider = None model = None @@ -264,7 +264,7 @@ class OpikDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -282,7 +282,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index a2f1969bc8..08d4adb2ff 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -1,3 +1,4 @@ +import collections import json import logging import os @@ -5,7 +6,7 @@ import queue import threading import time from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Union from uuid import UUID, uuid4 from cachetools import LRUCache @@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks logger = logging.getLogger(__name__) -class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): +class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): def __getitem__(self, provider: str) -> dict[str, Any]: match provider: case TracingProviderEnum.LANGFUSE: @@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): raise KeyError(f"Unsupported tracing provider: {provider}") -provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap() +provider_config_map = OpsTraceProviderConfigMap() class OpsTraceManager: @@ -218,7 +219,7 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -242,7 +243,7 @@ class OpsTraceManager: @classmethod def get_ops_trace_instance( cls, - app_id: Optional[Union[UUID, str]] = None, + app_id: Union[UUID, str] | None = None, ): """ Get ops trace through model config @@ -255,7 +256,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: return None @@ -329,7 +330,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.query(App).where(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -347,7 +348,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -405,11 +406,11 @@ class TraceTask: def __init__( self, trace_type: Any, - message_id: Optional[str] = None, - workflow_execution: Optional[WorkflowExecution] = None, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - timer: Optional[Any] = None, + message_id: str | None = None, + workflow_execution: WorkflowExecution | None = None, + conversation_id: str | None = None, + user_id: str | None = None, + timer: Any | None = None, **kwargs, ): self.trace_type = trace_type @@ -823,7 +824,7 @@ class TraceTask: return generate_name_trace_info -trace_manager_timer: Optional[threading.Timer] = None +trace_manager_timer: threading.Timer | None = None trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 2c0afb1600..5e8651d6f9 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from datetime import datetime -from typing import Optional, Union +from typing import Union from urllib.parse import urlparse from sqlalchemy import select @@ -49,9 +49,7 @@ def replace_text_with_content(data): return data -def generate_dotted_order( - run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None -) -> str: +def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str: """ generate dotted_order for langsmith """ diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index 7f489f37ac..ef1a3be45b 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, Field, field_validator from pydantic_core.core_schema import ValidationInfo @@ -8,24 +8,24 @@ from core.ops.utils import replace_text_with_content class WeaveTokenUsage(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - total_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None class WeaveMultiModel(BaseModel): - file_list: Optional[list[str]] = Field(None, description="List of files") + file_list: list[str] | None = Field(None, description="List of files") class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") - attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( + inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace") + outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace") + attributes: Union[str, dict[str, Any], list, None] | None = Field( None, description="Metadata and attributes associated with trace" ) - exception: Optional[str] = Field(None, description="Exception message of the trace") + exception: str | None = Field(None, description="Exception message of the trace") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 8eb94cc679..13a4529311 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Any, Optional, cast +from typing import Any, cast import wandb import weave @@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( @@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { @@ -223,7 +223,7 @@ class WeaveDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): # get message file data file_list = cast(list[str], trace_info.file_list) or [] - message_file_data: Optional[MessageFile] = trace_info.message_file_data + message_file_data: MessageFile | None = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) attributes = trace_info.metadata @@ -236,7 +236,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: Optional[EndUser] = ( + end_user_data: EndUser | None = ( db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -424,7 +424,7 @@ class WeaveDataTrace(BaseTraceInstance): logger.debug("Weave API check failed: %s", str(e)) raise ValueError(f"Weave API check failed: {str(e)}") - def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): + def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None): call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self.calls[run_data.id] = call if parent_run_id: diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 48f44da68e..9352a55be0 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Optional, Union +from typing import Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app = cls._get_app(app_id, tenant_id) """Retrieve app parameters.""" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app.workflow if workflow is None: raise ValueError("unexpected app type") @@ -53,8 +53,8 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app_id: str, user_id: str, tenant_id: str, - conversation_id: Optional[str], - query: Optional[str], + conversation_id: str | None, + query: str | None, stream: bool, inputs: Mapping, files: list[dict], @@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: if not query: raise ValueError("missing query") @@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT.value: + if app.mode == AppMode.ADVANCED_CHAT: workflow = app.workflow if not workflow: raise ValueError("unexpected app type") @@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.AGENT_CHAT.value: + elif app.mode == AppMode.AGENT_CHAT: return AgentChatAppGenerator().generate( app_model=app, user=user, @@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.CHAT.value: + elif app.mode == AppMode.CHAT: return ChatAppGenerator().generate( app_model=app, user=user, diff --git a/api/core/plugin/backwards_invocation/base.py b/api/core/plugin/backwards_invocation/base.py index 2a5f857576..a89b0f95be 100644 --- a/api/core/plugin/backwards_invocation/base.py +++ b/api/core/plugin/backwards_invocation/base.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from pydantic import BaseModel @@ -23,5 +23,5 @@ T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel) class BaseBackwardsInvocationResponse(BaseModel, Generic[T]): - data: Optional[T] = None + data: T | None = None error: str = "" diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 06773504d9..c2d1574e67 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,7 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool diff --git a/api/core/plugin/entities/endpoint.py b/api/core/plugin/entities/endpoint.py index d7ba75bb4f..e5bca140f8 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import BaseModel, Field, model_validator @@ -24,7 +23,7 @@ class EndpointProviderDeclaration(BaseModel): """ settings: list[ProviderConfig] = Field(default_factory=list) - endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration]) + endpoints: list[EndpointDeclaration] | None = Field(default_factory=list[EndpointDeclaration]) class EndpointEntity(BasePluginEntity): diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 1c13a621d4..e0762619e6 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field, model_validator from core.model_runtime.entities.provider_entities import ProviderEntity @@ -19,11 +17,11 @@ class MarketplacePluginDeclaration(BaseModel): resource: PluginResourceRequirements = Field( ..., description="Specification of computational resources needed to run the plugin" ) - endpoint: Optional[EndpointProviderDeclaration] = Field( + endpoint: EndpointProviderDeclaration | None = Field( None, description="Configuration for the plugin's API endpoint, if applicable" ) - model: Optional[ProviderEntity] = Field(None, description="Details of the AI model used by the plugin, if any") - tool: Optional[ToolProviderEntity] = Field( + model: ProviderEntity | None = Field(None, description="Details of the AI model used by the plugin, if any") + tool: ToolProviderEntity | None = Field( None, description="Information about the tool functionality provided by the plugin, if any" ) latest_version: str = Field( diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 92427a7426..0f7604b368 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,6 +1,6 @@ -import enum import json -from typing import Any, Optional, Union +from enum import StrEnum, auto +from typing import Any, Union from pydantic import BaseModel, Field, field_validator @@ -12,9 +12,7 @@ from core.workflow.nodes.base.entities import NumberType class PluginParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - icon: Optional[str] = Field( - default=None, description="The icon of the option, can be a url or a base64 encoded image" - ) + icon: str | None = Field(default=None, description="The icon of the option, can be a url or a base64 encoded image") @field_validator("value", mode="before") @classmethod @@ -25,44 +23,44 @@ class PluginParameterOption(BaseModel): return value -class PluginParameterType(enum.StrEnum): +class PluginParameterType(StrEnum): """ all available parameter types """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value - DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY + DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES # MCP object and array type parameters - ARRAY = CommonParameterType.ARRAY.value - OBJECT = CommonParameterType.OBJECT.value + ARRAY = CommonParameterType.ARRAY + OBJECT = CommonParameterType.OBJECT -class MCPServerParameterType(enum.StrEnum): +class MCPServerParameterType(StrEnum): """ MCP server got complex parameter types """ - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class PluginParameterAutoGenerate(BaseModel): - class Type(enum.StrEnum): - PROMPT_INSTRUCTION = "prompt_instruction" + class Type(StrEnum): + PROMPT_INSTRUCTION = auto() type: Type @@ -74,15 +72,15 @@ class PluginParameterTemplate(BaseModel): class PluginParameter(BaseModel): name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") - placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") + placeholder: I18nObject | None = Field(default=None, description="The placeholder presented to the user") scope: str | None = None - auto_generate: Optional[PluginParameterAutoGenerate] = None - template: Optional[PluginParameterTemplate] = None + auto_generate: PluginParameterAutoGenerate | None = None + template: PluginParameterTemplate | None = None required: bool = False - default: Optional[Union[float, int, str]] = None - min: Optional[Union[float, int]] = None - max: Optional[Union[float, int]] = None - precision: Optional[int] = None + default: Union[float, int, str] | None = None + min: Union[float, int] | None = None + max: Union[float, int] | None = None + precision: int | None = None options: list[PluginParameterOption] = Field(default_factory=list) @field_validator("options", mode="before") @@ -93,7 +91,7 @@ class PluginParameter(BaseModel): return v -def as_normal_type(typ: enum.StrEnum): +def as_normal_type(typ: StrEnum): if typ.value in { PluginParameterType.SECRET_INPUT, PluginParameterType.SELECT, @@ -102,7 +100,7 @@ def as_normal_type(typ: enum.StrEnum): return typ.value -def cast_parameter_value(typ: enum.StrEnum, value: Any, /): +def cast_parameter_value(typ: StrEnum, value: Any, /): try: match typ.value: case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT: @@ -190,7 +188,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") -def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any): +def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any): """ init frontend parameter by rule """ diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index a6369636e2..adc80d1e94 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -1,8 +1,8 @@ import datetime -import enum import re from collections.abc import Mapping -from typing import Any, Optional +from enum import StrEnum, auto +from typing import Any from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -16,11 +16,11 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -class PluginInstallationSource(enum.StrEnum): - Github = "github" - Marketplace = "marketplace" - Package = "package" - Remote = "remote" +class PluginInstallationSource(StrEnum): + Github = auto() + Marketplace = auto() + Package = auto() + Remote = auto() class PluginResourceRequirements(BaseModel): @@ -28,56 +28,56 @@ class PluginResourceRequirements(BaseModel): class Permission(BaseModel): class Tool(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Model(BaseModel): - enabled: Optional[bool] = Field(default=False) - llm: Optional[bool] = Field(default=False) - text_embedding: Optional[bool] = Field(default=False) - rerank: Optional[bool] = Field(default=False) - tts: Optional[bool] = Field(default=False) - speech2text: Optional[bool] = Field(default=False) - moderation: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) + llm: bool | None = Field(default=False) + text_embedding: bool | None = Field(default=False) + rerank: bool | None = Field(default=False) + tts: bool | None = Field(default=False) + speech2text: bool | None = Field(default=False) + moderation: bool | None = Field(default=False) class Node(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Endpoint(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) class Storage(BaseModel): - enabled: Optional[bool] = Field(default=False) + enabled: bool | None = Field(default=False) size: int = Field(ge=1024, le=1073741824, default=1048576) - tool: Optional[Tool] = Field(default=None) - model: Optional[Model] = Field(default=None) - node: Optional[Node] = Field(default=None) - endpoint: Optional[Endpoint] = Field(default=None) - storage: Optional[Storage] = Field(default=None) + tool: Tool | None = Field(default=None) + model: Model | None = Field(default=None) + node: Node | None = Field(default=None) + endpoint: Endpoint | None = Field(default=None) + storage: Storage | None = Field(default=None) - permission: Optional[Permission] = Field(default=None) + permission: Permission | None = Field(default=None) -class PluginCategory(enum.StrEnum): - Tool = "tool" - Model = "model" - Extension = "extension" +class PluginCategory(StrEnum): + Tool = auto() + Model = auto() + Extension = auto() AgentStrategy = "agent-strategy" class PluginDeclaration(BaseModel): class Plugins(BaseModel): - tools: Optional[list[str]] = Field(default_factory=list[str]) - models: Optional[list[str]] = Field(default_factory=list[str]) - endpoints: Optional[list[str]] = Field(default_factory=list[str]) + tools: list[str] | None = Field(default_factory=list[str]) + models: list[str] | None = Field(default_factory=list[str]) + endpoints: list[str] | None = Field(default_factory=list[str]) class Meta(BaseModel): - minimum_dify_version: Optional[str] = Field(default=None) - version: Optional[str] = Field(default=None) + minimum_dify_version: str | None = Field(default=None) + version: str | None = Field(default=None) @field_validator("minimum_dify_version") @classmethod - def validate_minimum_dify_version(cls, v: Optional[str]) -> Optional[str]: + def validate_minimum_dify_version(cls, v: str | None) -> str | None: if v is None: return v try: @@ -87,23 +87,23 @@ class PluginDeclaration(BaseModel): raise ValueError(f"Invalid version format: {v}") from e version: str = Field(...) - author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") + author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") description: I18nObject icon: str - icon_dark: Optional[str] = Field(default=None) + icon_dark: str | None = Field(default=None) label: I18nObject category: PluginCategory created_at: datetime.datetime resource: PluginResourceRequirements plugins: Plugins tags: list[str] = Field(default_factory=list) - repo: Optional[str] = Field(default=None) + repo: str | None = Field(default=None) verified: bool = Field(default=False) - tool: Optional[ToolProviderEntity] = None - model: Optional[ProviderEntity] = None - endpoint: Optional[EndpointProviderDeclaration] = None - agent_strategy: Optional[AgentStrategyProviderEntity] = None + tool: ToolProviderEntity | None = None + model: ProviderEntity | None = None + endpoint: EndpointProviderDeclaration | None = None + agent_strategy: AgentStrategyProviderEntity | None = None meta: Meta @field_validator("version") @@ -206,10 +206,10 @@ class ToolProviderID(GenericProviderID): class PluginDependency(BaseModel): - class Type(enum.StrEnum): - Github = PluginInstallationSource.Github.value - Marketplace = PluginInstallationSource.Marketplace.value - Package = PluginInstallationSource.Package.value + class Type(StrEnum): + Github = PluginInstallationSource.Github + Marketplace = PluginInstallationSource.Marketplace + Package = PluginInstallationSource.Package class Github(BaseModel): repo: str @@ -233,9 +233,9 @@ class PluginDependency(BaseModel): type: Type value: Github | Marketplace | Package - current_identifier: Optional[str] = None + current_identifier: str | None = None class MissingPluginDependency(BaseModel): plugin_unique_identifier: str - current_identifier: Optional[str] = None + current_identifier: str | None = None diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index f1d6860bb4..d6f0dd8121 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar from pydantic import BaseModel, ConfigDict, Field @@ -24,7 +24,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]): code: int message: str - data: Optional[T] = None + data: T | None = None class InstallPluginMessage(BaseModel): @@ -174,7 +174,7 @@ class PluginVerification(BaseModel): class PluginDecodeResponse(BaseModel): unique_identifier: str = Field(description="The unique identifier of the plugin.") manifest: PluginDeclaration - verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information") + verification: PluginVerification | None = Field(default=None, description="Basic verification information") class PluginOAuthAuthorizationUrlResponse(BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 3a783dad3e..10f37f75f8 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -35,7 +35,7 @@ class InvokeCredentials(BaseModel): class PluginInvokeContext(BaseModel): - credentials: Optional[InvokeCredentials] = Field( + credentials: InvokeCredentials | None = Field( default_factory=InvokeCredentials, description="Credentials context for the plugin invocation or backward invocation.", ) @@ -50,7 +50,7 @@ class RequestInvokeTool(BaseModel): provider: str tool: str tool_parameters: dict - credential_id: Optional[str] = None + credential_id: str | None = None class BaseRequestInvokeModel(BaseModel): @@ -70,9 +70,9 @@ class RequestInvokeLLM(BaseRequestInvokeModel): mode: str completion_params: dict[str, Any] = Field(default_factory=dict) prompt_messages: list[PromptMessage] = Field(default_factory=list) - tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool]) - stop: Optional[list[str]] = Field(default_factory=list[str]) - stream: Optional[bool] = False + tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool]) + stop: list[str] | None = Field(default_factory=list[str]) + stream: bool | None = False model_config = ConfigDict(protected_namespaces=()) @@ -194,10 +194,10 @@ class RequestInvokeApp(BaseModel): app_id: str inputs: dict[str, Any] - query: Optional[str] = None + query: str | None = None response_mode: Literal["blocking", "streaming"] - conversation_id: Optional[str] = None - user: Optional[str] = None + conversation_id: str | None = None + user: str | None = None files: list[dict] = Field(default_factory=list) diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 526f6f2961..0b55f20522 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.agent.entities import AgentInvokeMessage from core.plugin.entities.plugin import GenericProviderID @@ -82,10 +82,10 @@ class PluginAgentClient(BasePluginClient): agent_provider: str, agent_strategy: str, agent_params: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - context: Optional[PluginInvokeContext] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + context: PluginInvokeContext | None = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 85a72d9f82..153da142f4 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO, Optional +from typing import IO from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -151,9 +151,9 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, ) -> Generator[LLMResultChunk, None, None]: """ @@ -200,7 +200,7 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> int: """ Get number of tokens for llm @@ -325,8 +325,8 @@ class PluginModelClient(BasePluginClient): credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, + score_threshold: float | None = None, + top_n: int | None = None, ) -> RerankResult: """ Invoke rerank @@ -414,7 +414,7 @@ class PluginModelClient(BasePluginClient): provider: str, model: str, credentials: dict, - language: Optional[str] = None, + language: str | None = None, ): """ Get tts model voices diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 7199c0d15a..bb68f4700c 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -81,9 +81,9 @@ class PluginToolManager(BasePluginClient): credentials: dict[str, Any], credential_type: CredentialType, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters. @@ -153,9 +153,9 @@ class PluginToolManager(BasePluginClient): provider: str, credentials: dict[str, Any], tool: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters of the tool diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 11c6e5c23b..5f2ffefd94 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Optional, cast +from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -41,11 +41,11 @@ class AdvancedPromptTransform(PromptTransform): inputs: Mapping[str, str], query: str, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: prompt_messages = [] @@ -80,13 +80,13 @@ class AdvancedPromptTransform(PromptTransform): self, prompt_template: CompletionModelPromptTemplate, inputs: Mapping[str, str], - query: Optional[str], + query: str | None, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Get completion model prompt messages. @@ -141,13 +141,13 @@ class AdvancedPromptTransform(PromptTransform): self, prompt_template: list[ChatModelMessage], inputs: Mapping[str, str], - query: Optional[str], + query: str | None, files: Sequence[File], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], + context: str | None, + memory_config: MemoryConfig | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Get chat model prompt messages. diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 09f017a7db..a96b094e6d 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import cast from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, @@ -23,7 +23,7 @@ class AgentHistoryPromptTransform(PromptTransform): model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage], history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, + memory: TokenBufferMemory | None = None, ): self.model_config = model_config self.prompt_messages = prompt_messages diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index c8e7b414df..7094633093 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -12,7 +12,7 @@ class ChatModelMessage(BaseModel): text: str role: PromptMessageRole - edition_type: Optional[Literal["basic", "jinja2"]] = None + edition_type: Literal["basic", "jinja2"] | None = None class CompletionModelPromptTemplate(BaseModel): @@ -21,7 +21,7 @@ class CompletionModelPromptTemplate(BaseModel): """ text: str - edition_type: Optional[Literal["basic", "jinja2"]] = None + edition_type: Literal["basic", "jinja2"] | None = None class MemoryConfig(BaseModel): @@ -43,8 +43,8 @@ class MemoryConfig(BaseModel): """ enabled: bool - size: Optional[int] = None + size: int | None = None - role_prefix: Optional[RolePrefix] = None + role_prefix: RolePrefix | None = None window: WindowConfig - query_prompt_template: Optional[str] = None + query_prompt_template: str | None = None diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 1f040599be..a6e873d587 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -55,8 +55,8 @@ class PromptTransform: memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None, + human_prefix: str | None = None, + ai_prefix: str | None = None, ) -> str: """Get memory messages.""" kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d15cb7cbc1..d1d518a55d 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,8 +1,8 @@ -import enum import json import os from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from enum import StrEnum, auto +from typing import TYPE_CHECKING, Any, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -25,9 +25,9 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(enum.StrEnum): - COMPLETION = "completion" - CHAT = "chat" +class ModelMode(StrEnum): + COMPLETION = auto() + CHAT = auto() prompt_file_contents: dict[str, Any] = {} @@ -45,11 +45,11 @@ class SimplePromptTransform(PromptTransform): inputs: Mapping[str, str], query: str, files: Sequence["File"], - context: Optional[str], - memory: Optional[TokenBufferMemory], + context: str | None, + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode(model_config.mode) @@ -86,9 +86,9 @@ class SimplePromptTransform(PromptTransform): model_config: ModelConfigWithCredentialsEntity, pre_prompt: str, inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, + query: str | None = None, + context: str | None = None, + histories: str | None = None, ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( @@ -182,12 +182,12 @@ class SimplePromptTransform(PromptTransform): pre_prompt: str, inputs: dict, query: str, - context: Optional[str], + context: str | None, files: Sequence["File"], - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] # get prompt @@ -228,12 +228,12 @@ class SimplePromptTransform(PromptTransform): pre_prompt: str, inputs: dict, query: str, - context: Optional[str], + context: str | None, files: Sequence["File"], - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( app_mode=app_mode, @@ -281,7 +281,7 @@ class SimplePromptTransform(PromptTransform): self, prompt: str, files: Sequence["File"], - image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index e4e8b09a04..082c6c4c50 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,7 @@ import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -281,7 +281,7 @@ class ProviderManager: model_type_instance=model_type_instance, ) - def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: + def get_default_model(self, tenant_id: str, model_type: ModelType) -> DefaultModelEntity | None: """ Get default model. @@ -1036,8 +1036,8 @@ class ProviderManager: def _to_model_settings( self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + provider_model_settings: list[ProviderModelSetting] | None = None, + load_balancing_model_configs: list[LoadBalancingModelConfig] | None = None, ) -> list[ModelSettings]: """ Convert to model settings. diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index d17d76333e..696e3e967f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError @@ -18,8 +16,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + reranking_model: dict | None = None, + weights: dict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -29,9 +27,9 @@ class DataPostProcessor: self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -45,9 +43,9 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - ) -> Optional[BaseRerankRunner]: + reranking_model: dict | None = None, + weights: dict | None = None, + ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: runner = RerankRunnerFactory.create_rerank_runner( runner_type=reranking_mode, @@ -74,12 +72,12 @@ class DataPostProcessor: return runner return None - def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: + def _get_reorder_runner(self, reorder_enabled) -> ReorderRunner | None: if reorder_enabled: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 096f40f707..70690a4c56 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Optional +from typing import Any import orjson from pydantic import BaseModel @@ -143,7 +143,7 @@ class Jieba(BaseKeyword): storage.delete(file_key) storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) - def _get_dataset_keyword_table(self) -> Optional[dict]: + def _get_dataset_keyword_table(self) -> dict | None: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index a6214d955b..81619570f9 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,5 +1,5 @@ import re -from typing import Optional, cast +from typing import cast class JiebaKeywordTableHandler: @@ -10,7 +10,7 @@ class JiebaKeywordTableHandler: jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore - def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: + def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" import jieba.analyse # type: ignore diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index fefd42f84d..429744c0de 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,6 +1,5 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor -from typing import Optional from flask import Flask, current_app from sqlalchemy import select @@ -39,11 +38,11 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float] = 0.0, - reranking_model: Optional[dict] = None, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, reranking_mode: str = "reranking_model", - weights: Optional[dict] = None, - document_ids_filter: Optional[list[str]] = None, + weights: dict | None = None, + document_ids_filter: list[str] | None = None, ): if not query: return [] @@ -125,8 +124,8 @@ class RetrievalService: cls, dataset_id: str, query: str, - external_retrieval_model: Optional[dict] = None, - metadata_filtering_conditions: Optional[dict] = None, + external_retrieval_model: dict | None = None, + metadata_filtering_conditions: dict | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) @@ -145,7 +144,7 @@ class RetrievalService: return all_documents @classmethod - def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: + def _get_dataset(cls, dataset_id: str) -> Dataset | None: with Session(db.engine) as session: return session.query(Dataset).where(Dataset.id == dataset_id).first() @@ -158,7 +157,7 @@ class RetrievalService: top_k: int, all_documents: list, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: @@ -182,12 +181,12 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], + score_threshold: float | None, + reranking_model: dict | None, all_documents: list, retrieval_method: str, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: @@ -235,12 +234,12 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], + score_threshold: float | None, + reranking_model: dict | None, all_documents: list, retrieval_method: str, exceptions: list, - document_ids_filter: Optional[list[str]] = None, + document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): try: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index c3a6127e4a..77a0fa6cf2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: Optional[str] = None + namespace_password: str | None = None metrics: str = "cosine" read_timeout: int = 60000 diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index e7128b183e..de1572410c 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any import chromadb from chromadb import QueryResult, Settings @@ -20,8 +20,8 @@ class ChromaConfig(BaseModel): port: int tenant: str database: str - auth_provider: Optional[str] = None - auth_credentials: Optional[str] = None + auth_provider: str | None = None + auth_credentials: str | None = None def to_chroma_params(self): settings = Settings( diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index eb4cbd2324..e55e5f3101 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -84,7 +84,7 @@ class ClickzettaConnectionPool: self._pool_locks: dict[str, threading.Lock] = {} self._max_pool_size = 5 # Maximum connections per configuration self._connection_timeout = 300 # 5 minutes timeout - self._cleanup_thread: Optional[threading.Thread] = None + self._cleanup_thread: threading.Thread | None = None self._shutdown = False self._start_cleanup_thread() @@ -303,8 +303,8 @@ class ClickzettaVector(BaseVector): """ # Class-level write queue and lock for serializing writes - _write_queue: Optional[queue.Queue] = None - _write_thread: Optional[threading.Thread] = None + _write_queue: queue.Queue | None = None + _write_thread: threading.Thread | None = None _write_lock = threading.Lock() _shutdown = False @@ -328,7 +328,7 @@ class ClickzettaVector(BaseVector): def __init__(self, vector_instance: "ClickzettaVector"): self.vector = vector_instance - self.connection: Optional[Connection] = None + self.connection: Connection | None = None def __enter__(self) -> "Connection": self.connection = self.vector._get_connection() @@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector): for doc, embedding in zip(batch_docs, batch_embeddings): # Optimized: minimal checks for common case, fallback for edge cases - metadata = doc.metadata if doc.metadata else {} + metadata = doc.metadata or {} if not isinstance(metadata, dict): metadata = {} diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py index 7118029d40..7b00928b7b 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any from flask import current_app @@ -22,8 +22,8 @@ class ElasticSearchJaVector(ElasticSearchVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index df1c747585..2c147fa7ca 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import urlparse import requests @@ -24,18 +24,18 @@ logger = logging.getLogger(__name__) class ElasticSearchConfig(BaseModel): # Regular Elasticsearch config - host: Optional[str] = None - port: Optional[int] = None - username: Optional[str] = None - password: Optional[str] = None + host: str | None = None + port: int | None = None + username: str | None = None + password: str | None = None # Elastic Cloud specific config - cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud - api_key: Optional[str] = None + cloud_url: str | None = None # Cloud URL for Elasticsearch Cloud + api_key: str | None = None # Common config use_cloud: bool = False - ca_certs: Optional[str] = None + ca_certs: str | None = None verify_certs: bool = False request_timeout: int = 100000 retry_on_timeout: bool = True @@ -256,8 +256,8 @@ class ElasticSearchVector(BaseVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 9887e21b7c..8fc94be360 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -1,13 +1,13 @@ -from enum import Enum +from enum import StrEnum, auto -class Field(Enum): +class Field(StrEnum): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" GROUP_KEY = "group_id" - VECTOR = "vector" + VECTOR = auto() # Sparse Vector aims to support full text search - SPARSE_VECTOR = "sparse_vector" + SPARSE_VECTOR = auto() TEXT_KEY = "text" PRIMARY_KEY = "id" DOC_ID = "metadata.doc_id" diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 0eca37a129..cfee090768 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -1,7 +1,7 @@ import json import logging import ssl -from typing import Any, Optional +from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator @@ -157,8 +157,8 @@ class HuaweiCloudVector(BaseVector): def create_collection( self, embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + metadatas: list[dict[Any, Any]] | None = None, + index_params: dict | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 5097412c2c..f3ec30d178 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -2,7 +2,7 @@ import copy import json import logging import time -from typing import Any, Optional +from typing import Any from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError @@ -29,10 +29,10 @@ UGC_INDEX_PREFIX = "ugc_index" class LindormVectorStoreConfig(BaseModel): hosts: str - username: Optional[str] = None - password: Optional[str] = None - using_ugc: Optional[bool] = False - request_timeout: Optional[float] = 1.0 # timeout units: s + username: str | None = None + password: str | None = None + using_ugc: bool | None = False + request_timeout: float | None = 1.0 # timeout units: s @model_validator(mode="before") @classmethod @@ -448,13 +448,13 @@ def default_text_search_query( query_text: str, k: int = 4, text_field: str = Field.CONTENT_KEY.value, - must: Optional[list[dict]] = None, - must_not: Optional[list[dict]] = None, - should: Optional[list[dict]] = None, + must: list[dict] | None = None, + must_not: list[dict] | None = None, + should: list[dict] | None = None, minimum_should_match: int = 0, - filters: Optional[list[dict]] = None, - routing: Optional[str] = None, - routing_field: Optional[str] = None, + filters: list[dict] | None = None, + routing: str | None = None, + routing_field: str | None = None, **kwargs, ): query_clause: dict[str, Any] = {} @@ -505,13 +505,13 @@ def default_vector_search_query( query_vector: list[float], k: int = 4, min_score: str = "0.0", - ef_search: Optional[str] = None, # only for hnsw - nprobe: Optional[str] = None, # "2000" - reorder_factor: Optional[str] = None, # "20" - client_refactor: Optional[str] = None, # "true" + ef_search: str | None = None, # only for hnsw + nprobe: str | None = None, # "2000" + reorder_factor: str | None = None, # "20" + client_refactor: str | None = None, # "true" vector_field: str = Field.VECTOR.value, - filters: Optional[list[dict]] = None, - filter_type: Optional[str] = None, + filters: list[dict] | None = None, + filter_type: str | None = None, **kwargs, ): if filters is not None: diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 3dd073ce50..6fe396dc1e 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -3,7 +3,7 @@ import logging import uuid from collections.abc import Callable from functools import wraps -from typing import Any, Concatenate, Optional, ParamSpec, TypeVar +from typing import Any, Concatenate, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -74,7 +74,7 @@ class MatrixoneVector(BaseVector): self.client = self._get_client(len(embeddings[0]), True) return self.add_texts(texts, embeddings) - def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient: + def _get_client(self, dimension: int | None = None, create_table: bool = False) -> MoVectorClient: """ Create a new client for the collection. @@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector): self.client = self._get_client(len(embeddings[0]), True) assert self.client is not None ids = [] - for _, doc in enumerate(documents): + for doc in documents: if doc.metadata is not None: doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) ids.append(doc_id) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 2ec48ae365..5f32feb709 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any from packaging import version from pydantic import BaseModel, model_validator @@ -26,13 +26,13 @@ class MilvusConfig(BaseModel): """ uri: str # Milvus server URI - token: Optional[str] = None # Optional token for authentication - user: Optional[str] = None # Username for authentication - password: Optional[str] = None # Password for authentication + token: str | None = None # Optional token for authentication + user: str | None = None # Username for authentication + password: str | None = None # Password for authentication batch_size: int = 100 # Batch size for operations database: str = "default" # Database name enable_hybrid_search: bool = False # Flag to enable hybrid search - analyzer_params: Optional[str] = None # Analyzer params + analyzer_params: str | None = None # Analyzer params @model_validator(mode="before") @classmethod @@ -79,7 +79,7 @@ class MilvusVector(BaseVector): self._load_collection_fields() self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported - def _load_collection_fields(self, fields: Optional[list[str]] = None): + def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server collection_info = self._client.describe_collection(self._collection_name) @@ -292,7 +292,7 @@ class MilvusVector(BaseVector): ) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): """ Create a new collection in Milvus with the specified schema and index parameters. diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b590a4dfe4..17aac25b87 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -1,7 +1,7 @@ import json import logging import uuid -from enum import Enum +from enum import StrEnum from typing import Any from clickhouse_connect import get_client @@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel): fts_params: str -class SortOrder(Enum): +class SortOrder(StrEnum): ASC = "ASC" DESC = "DESC" diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 3f65a4a275..3eb1df027e 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Literal, Optional +from typing import Any, Literal from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -26,10 +26,10 @@ class OpenSearchConfig(BaseModel): secure: bool = False # use_ssl verify_certs: bool = True auth_method: Literal["basic", "aws_managed_iam"] = "basic" - user: Optional[str] = None - password: Optional[str] = None - aws_region: Optional[str] = None - aws_service: Optional[str] = None + user: str | None = None + password: str | None = None + aws_region: str | None = None + aws_service: str | None = None @model_validator(mode="before") @classmethod @@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector): }, } # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 - if self._client_config.aws_service not in ["aoss"]: + if self._client_config.aws_service != "aoss": action["_id"] = uuid4().hex actions.append(action) @@ -236,7 +236,7 @@ class OpenSearchVector(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None ): lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index d329220580..d46f29bd64 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import qdrant_client from flask import current_app @@ -46,7 +46,7 @@ class PathQdrantParams(BaseModel): class UrlQdrantParams(BaseModel): url: str - api_key: Optional[str] + api_key: str | None timeout: float verify: bool grpc_port: int @@ -55,9 +55,9 @@ class UrlQdrantParams(BaseModel): class QdrantConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 20 - root_path: Optional[str] = None + root_path: str | None = None grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 @@ -189,10 +189,10 @@ class QdrantVector(BaseVector): self, texts: Iterable[str], embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, - group_id: Optional[str] = None, + group_id: str | None = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -234,7 +234,7 @@ class QdrantVector(BaseVector): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, group_id: str, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 9d3dc7c622..99698fcdd0 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -1,6 +1,6 @@ import json import uuid -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert @@ -160,7 +160,7 @@ class RelytVector(BaseVector): else: return None - def delete_by_uuids(self, ids: Optional[list[str]] = None): + def delete_by_uuids(self, ids: list[str] | None = None): """Delete by vector IDs. Args: @@ -241,7 +241,7 @@ class RelytVector(BaseVector): self, embedding: list[float], k: int = 4, - filter: Optional[dict] = None, + filter: dict | None = None, ) -> list[tuple[Document, float]]: # Add the filter if provided diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 27685b7ddf..e91d9bb0d6 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -2,7 +2,7 @@ import json import logging import math from collections.abc import Iterable -from typing import Any, Optional +from typing import Any import tablestore # type: ignore from pydantic import BaseModel, model_validator @@ -22,11 +22,11 @@ logger = logging.getLogger(__name__) class TableStoreConfig(BaseModel): - access_key_id: Optional[str] = None - access_key_secret: Optional[str] = None - instance_name: Optional[str] = None - endpoint: Optional[str] = None - normalize_full_text_bm25_score: Optional[bool] = False + access_key_id: str | None = None + access_key_secret: str | None = None + instance_name: str | None = None + endpoint: str | None = None + normalize_full_text_bm25_score: bool | None = False @model_validator(mode="before") @classmethod diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 2485857070..291d047c04 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any from pydantic import BaseModel from tcvdb_text.encoder import BM25Encoder # type: ignore @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class TencentConfig(BaseModel): url: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 30 - username: Optional[str] = None - database: Optional[str] = None + username: str | None = None + database: str | None = None index_type: str = "HNSW" metric_type: str = "IP" shard: int = 1 diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 7055581459..f90a311df4 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import qdrant_client import requests @@ -45,9 +45,9 @@ if TYPE_CHECKING: class TidbOnQdrantConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None timeout: float = 20 - root_path: Optional[str] = None + root_path: str | None = None grpc_port: int = 6334 prefer_grpc: bool = False replication_factor: int = 1 @@ -180,10 +180,10 @@ class TidbOnQdrantVector(BaseVector): self, texts: Iterable[str], embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, - group_id: Optional[str] = None, + group_id: str | None = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -225,7 +225,7 @@ class TidbOnQdrantVector(BaseVector): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, group_id: str, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b2cc51d034..dc4f026ff3 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,7 +1,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from sqlalchemy import select @@ -32,7 +32,7 @@ class AbstractVectorFactory(ABC): class Vector: - def __init__(self, dataset: Dataset, attributes: Optional[list] = None): + def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset @@ -180,7 +180,7 @@ class Vector: case _: raise ValueError(f"Vector store {vector_type} is not supported.") - def create(self, texts: Optional[list] = None, **kwargs): + def create(self, texts: list | None = None, **kwargs): if texts: start = time.time() logger.info("start embedding %s texts %s", len(texts), start) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 43dde37c7e..3ec08b93ed 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Any, Optional +from typing import Any import requests import weaviate # type: ignore @@ -19,7 +19,7 @@ from models.dataset import Dataset class WeaviateConfig(BaseModel): endpoint: str - api_key: Optional[str] = None + api_key: str | None = None batch_size: int = 100 @model_validator(mode="before") diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 63c6db8d06..74a2653e9d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from sqlalchemy import func, select @@ -15,7 +15,7 @@ class DatasetDocumentStore: self, dataset: Dataset, user_id: str, - document_id: Optional[str] = None, + document_id: str | None = None, ): self._dataset = dataset self._user_id = user_id @@ -176,7 +176,7 @@ class DatasetDocumentStore: result = self.get_document_segment(doc_id) return result is not None - def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Document | None: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -217,16 +217,16 @@ class DatasetDocumentStore: document_segment.index_node_hash = doc_hash db.session.commit() - def get_document_hash(self, doc_id: str) -> Optional[str]: + def get_document_hash(self, doc_id: str) -> str | None: """Get the stored hash for a document, if it exists.""" document_segment = self.get_document_segment(doc_id) if document_segment is None: return None - data: Optional[str] = document_segment.index_node_hash + data: str | None = document_segment.index_node_hash return data - def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: + def get_document_segment(self, doc_id: str) -> DocumentSegment | None: stmt = select(DocumentSegment).where( DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id ) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 43be9cde69..5f94129a0c 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Any, Optional, cast +from typing import Any, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: Optional[str] = None): + def __init__(self, model_instance: ModelInstance, user: str | None = None): self._model_instance = model_instance self._user = user diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index 800422d888..8e92191568 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from models.dataset import DocumentSegment @@ -19,5 +17,5 @@ class RetrievalSegments(BaseModel): model_config = {"arbitrary_types_allowed": True} segment: DocumentSegment - child_chunks: Optional[list[RetrievalChildChunk]] = None - score: Optional[float] = None + child_chunks: list[RetrievalChildChunk] | None = None + score: float | None = None diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 00120425c9..aca879df7d 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -1,23 +1,23 @@ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel class RetrievalSourceMetadata(BaseModel): - position: Optional[int] = None - dataset_id: Optional[str] = None - dataset_name: Optional[str] = None - document_id: Optional[str] = None - document_name: Optional[str] = None - data_source_type: Optional[str] = None - segment_id: Optional[str] = None - retriever_from: Optional[str] = None - score: Optional[float] = None - hit_count: Optional[int] = None - word_count: Optional[int] = None - segment_position: Optional[int] = None - index_node_hash: Optional[str] = None - content: Optional[str] = None - page: Optional[int] = None - doc_metadata: Optional[dict[str, Any]] = None - title: Optional[str] = None + position: int | None = None + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + data_source_type: str | None = None + segment_id: str | None = None + retriever_from: str | None = None + score: float | None = None + hit_count: int | None = None + word_count: int | None = None + segment_position: int | None = None + index_node_hash: str | None = None + content: str | None = None + page: int | None = None + doc_metadata: dict[str, Any] | None = None + title: str | None = None diff --git a/api/core/rag/entities/context_entities.py b/api/core/rag/entities/context_entities.py index cd18ad081f..a2b03d54ba 100644 --- a/api/core/rag/entities/context_entities.py +++ b/api/core/rag/entities/context_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -9,4 +7,4 @@ class DocumentContext(BaseModel): """ content: str - score: Optional[float] = None + score: float | None = None diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index 1f054bccdb..b07d760cf4 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ SupportedComparisonOperator = Literal[ class Condition(BaseModel): """ - Conditon detail + Condition detail """ name: str @@ -43,5 +43,5 @@ class MetadataCondition(BaseModel): Metadata Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index 60dbc449f7..1f91a3ece1 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -12,7 +12,7 @@ import mimetypes from collections.abc import Generator, Mapping from io import BufferedReader, BytesIO from pathlib import Path, PurePath -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, model_validator @@ -30,17 +30,17 @@ class Blob(BaseModel): """ data: Union[bytes, str, None] = None # Raw data - mimetype: Optional[str] = None # Not to be confused with a file extension + mimetype: str | None = None # Not to be confused with a file extension encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string # Location where the original content was found # Represent location on the local file system # Useful for situations where downstream code assumes it must work with file paths # rather than in-memory content. - path: Optional[PathLike] = None + path: PathLike | None = None model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) @property - def source(self) -> Optional[str]: + def source(self) -> str | None: """The source location of the blob as string if known otherwise none.""" return str(self.path) if self.path else None @@ -91,7 +91,7 @@ class Blob(BaseModel): path: PathLike, *, encoding: str = "utf-8", - mime_type: Optional[str] = None, + mime_type: str | None = None, guess_type: bool = True, ) -> Blob: """Load the blob from a path like object. @@ -120,8 +120,8 @@ class Blob(BaseModel): data: Union[str, bytes], *, encoding: str = "utf-8", - mime_type: Optional[str] = None, - path: Optional[str] = None, + mime_type: str | None = None, + path: str | None = None, ) -> Blob: """Initialize the blob from in-memory data. diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 5b67403902..3bfae9d6bd 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" import csv -from typing import Optional import pandas as pd @@ -21,10 +20,10 @@ class CSVExtractor(BaseExtractor): def __init__( self, file_path: str, - encoding: Optional[str] = None, + encoding: str | None = None, autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + source_column: str | None = None, + csv_args: dict | None = None, ): """Initialize with file path.""" self._file_path = file_path diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py index 19ad300d11..6568f60ea2 100644 --- a/api/core/rag/extractor/entity/datasource_type.py +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class DatasourceType(Enum): +class DatasourceType(StrEnum): FILE = "upload_file" NOTION = "notion_import" WEBSITE = "website_crawl" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 52d64f591f..04a35d6f1f 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, ConfigDict from models.dataset import Document @@ -14,7 +12,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Optional[Document] = None + document: Document | None = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) @@ -43,10 +41,10 @@ class ExtractSetting(BaseModel): """ datasource_type: str - upload_file: Optional[UploadFile] = None - notion_info: Optional[NotionInfo] = None - website_info: Optional[WebsiteInfo] = None - document_model: Optional[str] = None + upload_file: UploadFile | None = None + notion_info: NotionInfo | None = None + website_info: WebsiteInfo | None = None + document_model: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, **data): diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index baa3fdf2eb..ea9c6bd73a 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional, cast +from typing import cast import pandas as pd from openpyxl import load_workbook @@ -18,7 +18,7 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): + def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index b5ea08173b..0c70844000 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -1,7 +1,7 @@ import re import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Union from urllib.parse import unquote from configs import dify_config @@ -90,7 +90,7 @@ class ExtractProcessor: @classmethod def extract( - cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: Optional[str] = None + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: @@ -104,7 +104,7 @@ class ExtractProcessor: input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE - extractor: Optional[BaseExtractor] = None + extractor: BaseExtractor | None = None if etl_type == "Unstructured": unstructured_api_url = dify_config.UNSTRUCTURED_API_URL or "" unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or "" diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 17f7d8661f..00004409d6 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -1,17 +1,17 @@ """Document loader helpers.""" import concurrent.futures -from typing import NamedTuple, Optional, cast +from typing import NamedTuple, cast class FileEncoding(NamedTuple): """A file encoding as the NamedTuple.""" - encoding: Optional[str] + encoding: str | None """The encoding of the file.""" confidence: float """The confidence of the encoding.""" - language: Optional[str] + language: str | None """The language of the file.""" diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index 3845392c8d..79d6ae2dac 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -2,7 +2,6 @@ import re from pathlib import Path -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -22,7 +21,7 @@ class MarkdownExtractor(BaseExtractor): file_path: str, remove_hyperlinks: bool = False, remove_images: bool = False, - encoding: Optional[str] = None, + encoding: str | None = None, autodetect_encoding: bool = True, ): """Initialize with file path.""" @@ -45,13 +44,13 @@ class MarkdownExtractor(BaseExtractor): return documents - def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]: + def markdown_to_tups(self, markdown_text: str) -> list[tuple[str | None, str]]: """Convert a markdown file to a dictionary. The keys are the headers and the values are the text under each header. """ - markdown_tups: list[tuple[Optional[str], str]] = [] + markdown_tups: list[tuple[str | None, str]] = [] lines = markdown_text.split("\n") current_header = None @@ -94,7 +93,7 @@ class MarkdownExtractor(BaseExtractor): content = re.sub(pattern, r"\1", content) return content - def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: + def parse_tups(self, filepath: str) -> list[tuple[str | None, str]]: """Parse file into tuples.""" content = "" try: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index fa96d73cf2..1779f26994 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,7 +1,7 @@ import json import logging import operator -from typing import Any, Optional, cast +from typing import Any, cast import requests from sqlalchemy import select @@ -36,8 +36,8 @@ class NotionExtractor(BaseExtractor): notion_obj_id: str, notion_page_type: str, tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, + document_model: DocumentModel | None = None, + notion_access_token: str | None = None, ): self._notion_access_token = None self._document_model = document_model @@ -328,7 +328,7 @@ class NotionExtractor(BaseExtractor): result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: Optional[DocumentModel]): + def update_last_edited_time(self, document_model: DocumentModel | None): if not document_model: return diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 3c43f34104..80530d99a6 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -2,7 +2,6 @@ import contextlib from collections.abc import Iterator -from typing import Optional from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -18,7 +17,7 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, file_cache_key: Optional[str] = None): + def __init__(self, file_path: str, file_cache_key: str | None = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index a00d328cb1..93f301ceff 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" from pathlib import Path -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -16,7 +15,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): + def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 4ed8dfbbd8..5199208f70 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -23,7 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor): unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: - import magic # noqa: F401 # pyright: ignore[reportUnusedImport] + import magic # noqa: F401 is_doc = detect_filetype(self._file_path) == FileType.DOC except ImportError: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2427de8292..ad04bd0bd1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,6 @@ import base64 import contextlib import logging -from typing import Optional from bs4 import BeautifulSoup @@ -17,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index fa91f7dd03..fc14ee6275 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import pypandoc # type: ignore @@ -20,7 +19,7 @@ class UnstructuredEpubExtractor(BaseExtractor): def __init__( self, file_path: str, - api_url: Optional[str] = None, + api_url: str | None = None, api_key: str = "", ): """Initialize with file path.""" diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 0a0c8d3a1c..23030d7739 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -16,7 +15,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index d363449c29..f29e639d1b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index ecc272a2f0..c12a55ee4b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index e7bf6fd2e6..99e3eec501 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 916cdc3f2b..d75e166f1b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): + def __init__(self, file_path: str, api_url: str | None = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index c59a70ea57..fe983aa86a 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient @@ -9,7 +9,7 @@ class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: Optional[dict | Any] = None): + def crawl_url(self, url, options: dict | Any | None = None): options = options or {} spider_options = { "max_depth": 1, diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index c8ad53e3dd..1d9ca89ba7 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -1,15 +1,15 @@ -from enum import Enum, StrEnum +from enum import StrEnum, auto class BuiltInField(StrEnum): - document_name = "document_name" - uploader = "uploader" - upload_date = "upload_date" - last_update_date = "last_update_date" - source = "source" + document_name = auto() + uploader = auto() + upload_date = auto() + last_update_date = auto() + source = auto() -class MetadataDataSource(Enum): +class MetadataDataSource(StrEnum): upload_file = "file_upload" website_crawl = "website" notion_import = "notion" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index c099fd1d5c..1e904e72e2 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,7 +1,6 @@ """Abstract interface for document loader implementations.""" from abc import ABC, abstractmethod -from typing import Optional from configs import dify_config from core.model_manager import ModelInstance @@ -31,7 +30,7 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod @@ -52,7 +51,7 @@ class BaseIndexProcessor(ABC): max_tokens: int, chunk_overlap: int, separator: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 997b0b953b..5e0b24c354 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,7 +1,6 @@ """Paragraph index processor.""" import uuid -from typing import Optional from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -85,7 +84,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index d1088af853..f87e61b51c 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,7 +1,6 @@ """Paragraph index processor.""" import uuid -from typing import Optional from configs import dify_config from core.model_manager import ModelInstance @@ -109,25 +108,37 @@ class ParentChildIndexProcessor(BaseIndexProcessor): ] vector.create(formatted_child_documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False + precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) + if node_ids: - child_node_ids = ( - db.session.query(ChildChunk.index_node_id) - .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) - .where( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.index_node_id.in_(node_ids), - ChildChunk.dataset_id == dataset.id, + # Use precomputed child_node_ids if available (to avoid race conditions) + if precomputed_child_node_ids is not None: + child_node_ids = precomputed_child_node_ids + else: + # Fallback to original query (may fail if segments are already deleted) + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() ) - .all() - ) - child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] - vector.delete_by_ids(child_node_ids) - if delete_child_chunks: + child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]] + + # Delete from vector index + if child_node_ids: + vector.delete_by_ids(child_node_ids) + + # Delete from database + if delete_child_chunks and child_node_ids: db.session.query(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) ).delete(synchronize_session=False) @@ -175,7 +186,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): document_node: Document, rules: Rule, process_rule_mode: str, - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, ) -> list[ChildDocument]: if not rules.subchunk_segmentation: raise ValueError("No subchunk segmentation found in rules.") diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index a4ec828e2f..2ca444ca86 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,6 @@ import logging import re import threading import uuid -from typing import Optional import pandas as pd from flask import Flask, current_app @@ -128,7 +127,7 @@ class QAIndexProcessor(BaseIndexProcessor): vector = Vector(dataset) vector.create(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index ff63a6780e..b70d8bf559 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -10,7 +10,7 @@ class ChildDocument(BaseModel): page_content: str - vector: Optional[list[float]] = None + vector: list[float] | None = None """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). @@ -23,16 +23,16 @@ class Document(BaseModel): page_content: str - vector: Optional[list[float]] = None + vector: list[float] | None = None """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ metadata: dict = Field(default_factory=dict) - provider: Optional[str] = "dify" + provider: str | None = "dify" - children: Optional[list[ChildDocument]] = None + children: list[ChildDocument] | None = None class BaseDocumentTransformer(ABC): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 818b04b2ff..3561def008 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from core.rag.models.document import Document @@ -10,9 +9,9 @@ class BaseRerankRunner(ABC): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 7a6ebd1f39..e855b0083f 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_manager import ModelInstance from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner @@ -13,9 +11,9 @@ class RerankModelRunner(BaseRerankRunner): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index ab49e43b70..c455db6095 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -1,6 +1,5 @@ import math from collections import Counter -from typing import Optional import numpy as np @@ -22,9 +21,9 @@ class WeightRerankRunner(BaseRerankRunner): self, query: str, documents: list[Document], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 93bad23f2b..b08f80da49 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -4,7 +4,7 @@ import re import threading from collections import Counter, defaultdict from collections.abc import Generator, Mapping -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import Float, and_, or_, select, text @@ -85,9 +85,9 @@ class DatasetRetrieval: show_retrieve_source: bool, hit_callback: DatasetIndexToolCallbackHandler, message_id: str, - memory: Optional[TokenBufferMemory] = None, - inputs: Optional[Mapping[str, Any]] = None, - ) -> Optional[str]: + memory: TokenBufferMemory | None = None, + inputs: Mapping[str, Any] | None = None, + ) -> str | None: """ Retrieve dataset. :param app_id: app_id @@ -290,9 +290,9 @@ class DatasetRetrieval: model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, - metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, - metadata_condition: Optional[MetadataCondition] = None, + message_id: str | None = None, + metadata_filter_document_ids: dict[str, list[str]] | None = None, + metadata_condition: MetadataCondition | None = None, ): tools = [] for dataset in available_datasets: @@ -410,12 +410,12 @@ class DatasetRetrieval: top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict[str, Any]] = None, + reranking_model: dict | None = None, + weights: dict[str, Any] | None = None, reranking_enable: bool = True, - message_id: Optional[str] = None, - metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, - metadata_condition: Optional[MetadataCondition] = None, + message_id: str | None = None, + metadata_filter_document_ids: dict[str, list[str]] | None = None, + metadata_condition: MetadataCondition | None = None, ): if not available_datasets: return [] @@ -505,9 +505,7 @@ class DatasetRetrieval: return all_documents - def _on_retrieval_end( - self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None - ): + def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: @@ -588,8 +586,8 @@ class DatasetRetrieval: query: str, top_k: int, all_documents: list, - document_ids_filter: Optional[list[str]] = None, - metadata_condition: Optional[MetadataCondition] = None, + document_ids_filter: list[str] | None = None, + metadata_condition: MetadataCondition | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -664,7 +662,7 @@ class DatasetRetrieval: hit_callback: DatasetIndexToolCallbackHandler, user_id: str, inputs: dict, - ) -> Optional[list[DatasetRetrieverBaseTool]]: + ) -> list[DatasetRetrieverBaseTool] | None: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -853,9 +851,9 @@ class DatasetRetrieval: user_id: str, metadata_filtering_mode: str, metadata_model_config: ModelConfig, - metadata_filtering_conditions: Optional[MetadataFilteringCondition], + metadata_filtering_conditions: MetadataFilteringCondition | None, inputs: dict, - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", @@ -950,7 +948,7 @@ class DatasetRetrieval: def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: # get all metadata field metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) metadata_fields = db.session.scalars(metadata_stmt).all() @@ -1005,7 +1003,7 @@ class DatasetRetrieval: return automatic_metadata_filters def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index d654463be9..8356861242 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer @@ -24,7 +24,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @classmethod def from_encoder( cls: type[TS], - embedding_model_instance: Optional[ModelInstance], + embedding_model_instance: ModelInstance | None, allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 **kwargs: Any, @@ -48,7 +48,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): - def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): + def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index c5b6ac4608..41e6d771e9 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import ( Any, Literal, - Optional, TypeVar, Union, ) @@ -71,7 +70,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -94,7 +93,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): metadatas.append(doc.metadata or {}) return self.create_documents(texts, metadatas=metadatas) - def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + def _join_docs(self, docs: list[str], separator: str) -> str | None: text = separator.join(docs) text = text.strip() if text == "": @@ -110,9 +109,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): docs = [] current_doc: list[str] = [] total = 0 - index = 0 - for d in splits: - _len = lengths[index] + for d, _len in zip(splits, lengths): if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( @@ -134,7 +131,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) - index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) @@ -197,7 +193,7 @@ class TokenTextSplitter(TextSplitter): def __init__( self, encoding_name: str = "gpt2", - model_name: Optional[str] = None, + model_name: str | None = None, allowed_special: Union[Literal["all"], Set[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -248,7 +244,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): def __init__( self, - separators: Optional[list[str]] = None, + separators: list[str] | None = None, keep_separator: bool = True, **kwargs: Any, ): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d6f40491b6..eda7b54d6a 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -6,7 +6,7 @@ providing improved performance by offloading database operations to background w """ import logging -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -39,8 +39,8 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): _session_factory: sessionmaker _tenant_id: str - _app_id: Optional[str] - _triggered_from: Optional[WorkflowRunTriggeredFrom] + _app_id: str | None + _triggered_from: WorkflowRunTriggeredFrom | None _creator_user_id: str _creator_user_role: CreatorUserRole @@ -48,8 +48,8 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowRunTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, ): """ Initialize the repository with Celery task configuration and context information. diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 95ad9f25fe..21a0b7eefe 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -44,8 +44,8 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): _session_factory: sessionmaker _tenant_id: str - _app_id: Optional[str] - _triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom] + _app_id: str | None + _triggered_from: WorkflowNodeExecutionTriggeredFrom | None _creator_user_id: str _creator_user_role: CreatorUserRole _execution_cache: dict[str, WorkflowNodeExecution] @@ -55,8 +55,8 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, ): """ Initialize the repository with Celery task configuration and context information. @@ -151,7 +151,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 46b028b219..7d1069e28f 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -4,7 +4,7 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging -from typing import Optional, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -44,8 +44,8 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowRunTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, ): """ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. @@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): else None ) db_model.status = domain_model.status - db_model.error = domain_model.error_message if domain_model.error_message else None + db_model.error = domain_model.error_message or None db_model.total_tokens = domain_model.total_tokens db_model.total_steps = domain_model.total_steps db_model.exceptions_count = domain_model.exceptions_count diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8702af9f80..de5fca9f44 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -5,7 +5,7 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. import json import logging from collections.abc import Sequence -from typing import Optional, Union +from typing import Union import psycopg2.errors from sqlalchemy import UnaryExpression, asc, desc, select @@ -52,8 +52,8 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) self, session_factory: sessionmaker | Engine, user: Union[Account, EndUser], - app_id: Optional[str], - triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, ): """ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. @@ -279,7 +279,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_db_models_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecutionModel]: """ Retrieve all WorkflowNodeExecution database models for a specific workflow run. @@ -334,7 +334,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index eb179f15ee..82616596f8 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from models.model import File @@ -46,9 +46,9 @@ class Tool(ABC): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage]: if self.runtime and self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) @@ -96,17 +96,17 @@ class Tool(ABC): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: pass def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters @@ -119,9 +119,9 @@ class Tool(ABC): def get_merged_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get merged runtime parameters @@ -196,7 +196,7 @@ class Tool(ABC): message=ToolInvokeMessage.TextMessage(text=text), ) - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index ddec7b1329..3de0014c61 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from openai import BaseModel from pydantic import Field @@ -13,9 +13,9 @@ class ToolRuntime(BaseModel): """ tenant_id: str - tool_id: Optional[str] = None - invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None + tool_id: str | None = None + invoke_from: InvokeFrom | None = None + tool_invoke_from: ToolInvokeFrom | None = None credentials: dict[str, Any] = Field(default_factory=dict) credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 68bfe5b4a5..45fd16d684 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -18,7 +18,7 @@ from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, ) -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import load_yaml_file_cached class BuiltinToolProviderController(ToolProviderController): @@ -31,7 +31,7 @@ class BuiltinToolProviderController(ToolProviderController): provider = self.__class__.__module__.split(".")[-1] yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml") try: - provider_yaml = load_yaml_file(yaml_path, ignore_error=False) + provider_yaml = load_yaml_file_cached(yaml_path) except Exception as e: raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") @@ -71,7 +71,7 @@ class BuiltinToolProviderController(ToolProviderController): for tool_file in tool_files: # get tool name tool_name = tool_file.split(".")[0] - tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) + tool = load_yaml_file_cached(path.join(tool_path, tool_file)) # get tool class, import the module assistant_tool_class: type = load_single_subclass_from_source( diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 5c24920871..af9b5b31c2 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.file.enums import FileType from core.file.file_manager import download @@ -18,9 +18,9 @@ class ASRTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore @@ -56,9 +56,9 @@ class ASRTool(BuiltinTool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: parameters = [] diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f191968812..8bc159bb85 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -16,9 +16,9 @@ class TTSTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: provider, model = tool_parameters.get("model").split("#") # type: ignore voice = tool_parameters.get(f"voice#{provider}#{model}") @@ -72,9 +72,9 @@ class TTSTool(BuiltinTool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: parameters = [] diff --git a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py index b4e650e0ed..4383943199 100644 --- a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py +++ b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage from core.tools.builtin_tool.tool import BuiltinTool @@ -12,9 +12,9 @@ class SimpleCode(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke simple code diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index d054afac96..44f94c2723 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any from pytz import timezone as pytz_timezone @@ -13,9 +13,9 @@ class CurrentTimeTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tools diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index a8fd6ec2cd..197b062e44 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class LocaltimeToTimestampTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert localtime to timestamp diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 0ef6331530..462e4be5ce 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class TimestampToLocaltimeTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert timestamp to localtime diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index 91316b859a..babfa9bcd9 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any import pytz @@ -14,9 +14,9 @@ class TimezoneConversionTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Convert time to equivalent time zone diff --git a/api/core/tools/builtin_tool/providers/time/tools/weekday.py b/api/core/tools/builtin_tool/providers/time/tools/weekday.py index 158ce701c0..e26b316bd5 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/weekday.py +++ b/api/core/tools/builtin_tool/providers/time/tools/weekday.py @@ -1,7 +1,7 @@ import calendar from collections.abc import Generator from datetime import datetime -from typing import Any, Optional +from typing import Any from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -12,9 +12,9 @@ class WeekdayTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Calculate the day of the week for a given date diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py index 3bee710879..9d668ac9eb 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -12,9 +12,9 @@ class WebscraperTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tools diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 190af999b1..13dd2114d3 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator from dataclasses import dataclass from os import getenv -from typing import Any, Optional, Union +from typing import Any, Union from urllib.parse import urlencode import httpx @@ -376,9 +376,9 @@ class ApiTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke http request diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 27fe70bfc4..1eacd198cb 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -14,9 +14,9 @@ class ToolApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None + output_schema: dict | None = None ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]] @@ -28,24 +28,25 @@ class ToolProviderApiEntity(BaseModel): name: str # identifier description: I18nObject icon: str | dict - icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool") label: I18nObject # label type: ToolProviderType - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None + masked_credentials: dict | None = None + original_credentials: dict | None = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + plugin_id: str | None = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool") tools: list[ToolApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) # MCP - server_url: Optional[str] = Field(default="", description="The server url of the tool") + server_url: str | None = Field(default="", description="The server url of the tool") updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool") - timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool") - sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool") - masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool") + server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool") + timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool") + sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") + masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") + original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") @field_validator("tools", mode="before") @classmethod diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index aadbbeb843..2c6d9c1964 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field @@ -9,9 +7,9 @@ class I18nObject(BaseModel): """ en_US: str - zh_Hans: Optional[str] = Field(default=None) - pt_BR: Optional[str] = Field(default=None) - ja_JP: Optional[str] = Field(default=None) + zh_Hans: str | None = Field(default=None) + pt_BR: str | None = Field(default=None) + ja_JP: str | None = Field(default=None) def __init__(self, **data): super().__init__(**data) diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index ffeeabbc1c..eba20b07f0 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from core.tools.entities.tool_entities import ToolParameter @@ -16,14 +14,14 @@ class ApiToolBundle(BaseModel): # method method: str # summary - summary: Optional[str] = None + summary: str | None = None # operation_id - operation_id: Optional[str] = None + operation_id: str | None = None # parameters - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None # author author: str # icon - icon: Optional[str] = None + icon: str | None = None # openapi operation openapi: dict diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 66304b30a5..62dad1a50b 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,9 +1,8 @@ import base64 import contextlib -import enum from collections.abc import Mapping -from enum import Enum -from typing import Any, Optional, Union +from enum import StrEnum, auto +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator @@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY -class ToolLabelEnum(Enum): - SEARCH = "search" - IMAGE = "image" - VIDEOS = "videos" - WEATHER = "weather" - FINANCE = "finance" - DESIGN = "design" - TRAVEL = "travel" - SOCIAL = "social" - NEWS = "news" - MEDICAL = "medical" - PRODUCTIVITY = "productivity" - EDUCATION = "education" - BUSINESS = "business" - ENTERTAINMENT = "entertainment" - UTILITIES = "utilities" - OTHER = "other" +class ToolLabelEnum(StrEnum): + SEARCH = auto() + IMAGE = auto() + VIDEOS = auto() + WEATHER = auto() + FINANCE = auto() + DESIGN = auto() + TRAVEL = auto() + SOCIAL = auto() + NEWS = auto() + MEDICAL = auto() + PRODUCTIVITY = auto() + EDUCATION = auto() + BUSINESS = auto() + ENTERTAINMENT = auto() + UTILITIES = auto() + OTHER = auto() -class ToolProviderType(enum.StrEnum): +class ToolProviderType(StrEnum): """ Enum class for tool provider """ - PLUGIN = "plugin" + PLUGIN = auto() BUILT_IN = "builtin" - WORKFLOW = "workflow" - API = "api" - APP = "app" + WORKFLOW = auto() + API = auto() + APP = auto() DATASET_RETRIEVAL = "dataset-retrieval" - MCP = "mcp" + MCP = auto() @classmethod def value_of(cls, value: str) -> "ToolProviderType": @@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum): raise ValueError(f"invalid mode value {value}") -class ApiProviderSchemaType(Enum): +class ApiProviderSchemaType(StrEnum): """ Enum class for api provider schema type. """ - OPENAPI = "openapi" - SWAGGER = "swagger" - OPENAI_PLUGIN = "openai_plugin" - OPENAI_ACTIONS = "openai_actions" + OPENAPI = auto() + SWAGGER = auto() + OPENAI_PLUGIN = auto() + OPENAI_ACTIONS = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderSchemaType": @@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum): raise ValueError(f"invalid mode value {value}") -class ApiProviderAuthType(Enum): +class ApiProviderAuthType(StrEnum): """ Enum class for api provider auth type. """ - NONE = "none" - API_KEY_HEADER = "api_key_header" - API_KEY_QUERY = "api_key_query" + NONE = auto() + API_KEY_HEADER = auto() + API_KEY_QUERY = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderAuthType": @@ -176,36 +175,36 @@ class ToolInvokeMessage(BaseModel): return value class LogMessage(BaseModel): - class LogStatus(Enum): - START = "start" - ERROR = "error" - SUCCESS = "success" + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() id: str label: str = Field(..., description="The label of the log") - parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") - error: Optional[str] = Field(default=None, description="The error message") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") status: LogStatus = Field(..., description="The status of the log") data: Mapping[str, Any] = Field(..., description="Detailed log data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log") class RetrieverResourceMessage(BaseModel): retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") - class MessageType(Enum): - TEXT = "text" - IMAGE = "image" - LINK = "link" - BLOB = "blob" - JSON = "json" - IMAGE_LINK = "image_link" - BINARY_LINK = "binary_link" - VARIABLE = "variable" - FILE = "file" - LOG = "log" - BLOB_CHUNK = "blob_chunk" - RETRIEVER_RESOURCES = "retriever_resources" + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() type: MessageType = MessageType.TEXT """ @@ -242,7 +241,7 @@ class ToolInvokeMessage(BaseModel): class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - file_var: Optional[dict[str, Any]] = None + file_var: dict[str, Any] | None = None class ToolParameter(PluginParameter): @@ -250,29 +249,29 @@ class ToolParameter(PluginParameter): Overrides type """ - class ToolParameterType(enum.StrEnum): + class ToolParameterType(StrEnum): """ removes TOOLS_SELECTOR from PluginParameterType """ - STRING = PluginParameterType.STRING.value - NUMBER = PluginParameterType.NUMBER.value - BOOLEAN = PluginParameterType.BOOLEAN.value - SELECT = PluginParameterType.SELECT.value - SECRET_INPUT = PluginParameterType.SECRET_INPUT.value - FILE = PluginParameterType.FILE.value - FILES = PluginParameterType.FILES.value - APP_SELECTOR = PluginParameterType.APP_SELECTOR.value - MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value - ANY = PluginParameterType.ANY.value - DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value + STRING = PluginParameterType.STRING + NUMBER = PluginParameterType.NUMBER + BOOLEAN = PluginParameterType.BOOLEAN + SELECT = PluginParameterType.SELECT + SECRET_INPUT = PluginParameterType.SECRET_INPUT + FILE = PluginParameterType.FILE + FILES = PluginParameterType.FILES + APP_SELECTOR = PluginParameterType.APP_SELECTOR + MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR + ANY = PluginParameterType.ANY + DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT # MCP object and array type parameters - ARRAY = MCPServerParameterType.ARRAY.value - OBJECT = MCPServerParameterType.OBJECT.value + ARRAY = MCPServerParameterType.ARRAY + OBJECT = MCPServerParameterType.OBJECT # deprecated, should not use. - SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -280,17 +279,17 @@ class ToolParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + class ToolParameterForm(StrEnum): + SCHEMA = auto() # should be set while adding tool + FORM = auto() # should be set before invoking tool + LLM = auto() # will be set by LLM type: ToolParameterType = Field(..., description="The type of the parameter") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + human_description: I18nObject | None = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") - llm_description: Optional[str] = None + llm_description: str | None = None # MCP object and array type parameters use this field to store the schema - input_schema: Optional[dict] = None + input_schema: dict | None = None @classmethod def get_simple_instance( @@ -299,7 +298,7 @@ class ToolParameter(PluginParameter): llm_description: str, typ: ToolParameterType, required: bool, - options: Optional[list[str]] = None, + options: list[str] | None = None, ) -> "ToolParameter": """ get a simple tool parameter @@ -340,9 +339,9 @@ class ToolProviderIdentity(BaseModel): name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") - icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | None = Field(default=None, description="The dark icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( + tags: list[ToolLabelEnum] | None = Field( default=[], description="The tags of the tool", ) @@ -353,7 +352,7 @@ class ToolIdentity(BaseModel): name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") provider: str = Field(..., description="The provider of the tool") - icon: Optional[str] = None + icon: str | None = None class ToolDescription(BaseModel): @@ -364,8 +363,8 @@ class ToolDescription(BaseModel): class ToolEntity(BaseModel): identity: ToolIdentity parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None - output_schema: Optional[dict] = None + description: ToolDescription | None = None + output_schema: dict | None = None has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") # pydantic configs @@ -386,9 +385,9 @@ class OAuthSchema(BaseModel): class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity - plugin_id: Optional[str] = None + plugin_id: str | None = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = None + oauth_schema: OAuthSchema | None = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -411,8 +410,8 @@ class ToolInvokeMeta(BaseModel): """ time_cost: float = Field(..., description="The time cost of the tool invoke") - error: Optional[str] = None - tool_config: Optional[dict] = None + error: str | None = None + tool_config: dict | None = None @classmethod def empty(cls) -> "ToolInvokeMeta": @@ -446,14 +445,14 @@ class ToolLabel(BaseModel): icon: str = Field(..., description="The icon of the tool") -class ToolInvokeFrom(Enum): +class ToolInvokeFrom(StrEnum): """ Enum class for tool invoke """ - WORKFLOW = "workflow" - AGENT = "agent" - PLUGIN = "plugin" + WORKFLOW = auto() + AGENT = auto() + PLUGIN = auto() class ToolSelector(BaseModel): @@ -464,11 +463,11 @@ class ToolSelector(BaseModel): type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") required: bool = Field(..., description="Whether the parameter is required") description: str = Field(..., description="The description of the parameter") - default: Optional[Union[int, float, str]] = None - options: Optional[list[PluginParameterOption]] = None + default: Union[int, float, str] | None = None + options: list[PluginParameterOption] | None = None provider_id: str = Field(..., description="The id of the provider") - credential_id: Optional[str] = Field(default=None, description="The id of the credential") + credential_id: str | None = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") @@ -478,9 +477,9 @@ class ToolSelector(BaseModel): return self.model_dump() -class CredentialType(enum.StrEnum): +class CredentialType(StrEnum): API_KEY = "api-key" - OAUTH2 = "oauth2" + OAUTH2 = auto() def get_name(self): if self == CredentialType.API_KEY: diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index bd8bc73e63..aac2c404ea 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Self +from typing import Any, Self from core.entities.mcp_provider import MCPProviderEntity from core.mcp.types import Tool as RemoteMCPTool @@ -25,9 +25,9 @@ class MCPToolProviderController(ToolProviderController): provider_id: str, tenant_id: str, server_url: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): super().__init__(entity) self.entity: ToolProviderEntityWithPlugin = entity diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 6398cc56a9..12749c9b89 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,7 +1,7 @@ import base64 import json from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry @@ -22,9 +22,9 @@ class MCPTool(Tool): icon: str, server_url: str, provider_id: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): super().__init__(entity, runtime) self.tenant_id = tenant_id @@ -42,9 +42,9 @@ class MCPTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: result = self.invoke_remote_mcp_tool(tool_parameters) # handle dify tool output @@ -144,11 +144,12 @@ class MCPTool(Tool): if mcp_service: try: provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True) - + headers = provider_entity.decrypt_headers() # Try to get existing token and add to headers - tokens = provider_entity.retrieve_tokens() - if tokens and tokens.access_token: - headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" + if not headers: + tokens = provider_entity.retrieve_tokens() + if tokens and tokens.access_token: + headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" except Exception: # If provider retrieval or token fails, continue without auth pass diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index e649caec1d..828dc3b810 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.plugin.impl.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -16,7 +16,7 @@ class PluginTool(Tool): self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters: Optional[list[ToolParameter]] = None + self.runtime_parameters: list[ToolParameter] | None = None def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN @@ -25,9 +25,9 @@ class PluginTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() @@ -57,9 +57,9 @@ class PluginTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: """ get the runtime parameters diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index c3fdc37303..0154ffe883 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterable from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from yarl import URL @@ -51,10 +51,10 @@ class ToolEngine: message: Message, invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + trace_manager: TraceQueueManager | None = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> tuple[str, list[str], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. @@ -152,10 +152,10 @@ class ToolEngine: user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, - thread_pool_id: Optional[str] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + thread_pool_id: str | None = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Workflow invokes the tool with the given arguments. @@ -196,9 +196,9 @@ class ToolEngine: tool: Tool, tool_parameters: dict, user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: """ Invoke the tool with the given arguments. diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ad650196ce..6289f1d335 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,7 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Optional, Union +from typing import Union from uuid import uuid4 import httpx @@ -72,10 +72,10 @@ class ToolFileManager: *, user_id: str, tenant_id: str, - conversation_id: Optional[str], + conversation_id: str | None, file_binary: bytes, mimetype: str, - filename: Optional[str] = None, + filename: str | None = None, ) -> ToolFile: extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex @@ -112,7 +112,7 @@ class ToolFileManager: user_id: str, tenant_id: str, file_url: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> ToolFile: # try to download image try: @@ -217,7 +217,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optional[Generator], Optional[ToolFile]]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: """ get file binary diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index a8f7267d35..cca24978ec 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Union, cast import sqlalchemy as sa from pydantic import TypeAdapter @@ -157,7 +157,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -358,7 +358,7 @@ class ToolManager: app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional[VariablePool] = None, + variable_pool: VariablePool | None = None, ) -> Tool: """ get the agent tool runtime @@ -400,7 +400,7 @@ class ToolManager: node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional[VariablePool] = None, + variable_pool: VariablePool | None = None, ) -> Tool: """ get the workflow tool runtime @@ -443,7 +443,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Tool: """ get tool runtime from plugin @@ -973,7 +973,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional[VariablePool], + variable_pool: VariablePool | None, tool_configurations: dict[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index 2e572099b3..ac2967d0c1 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -13,7 +12,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): description: str = "use this to retrieve a dataset. " tenant_id: str top_k: int = 4 - score_threshold: Optional[float] = None + score_threshold: float | None = None hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] return_resource: bool retriever_from: str diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index b536c5a25c..0e2237befd 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel, Field from sqlalchemy import select @@ -37,7 +37,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - user_id: Optional[str] = None + user_id: str | None = None retrieve_config: DatasetRetrieveConfigEntity inputs: dict diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index d5803e33e7..a62d419243 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -87,9 +87,9 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters( self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> list[ToolParameter]: return [ ToolParameter( @@ -112,9 +112,9 @@ class DatasetRetrieverTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke dataset retriever tool diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 5820be0ffb..45ad14cb8e 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -1,6 +1,6 @@ import contextlib from copy import deepcopy -from typing import Any, Optional, Protocol +from typing import Any, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter @@ -13,7 +13,7 @@ class ProviderConfigCache(Protocol): Interface for provider configuration cache operations """ - def get(self) -> Optional[dict]: + def get(self) -> dict | None: """Get cached provider configuration""" ... diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bf075bd730..0851a54338 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -3,7 +3,6 @@ from collections.abc import Generator from datetime import date, datetime from decimal import Decimal from mimetypes import guess_extension -from typing import Optional from uuid import UUID import numpy as np @@ -60,7 +59,7 @@ class ToolFileMessageTransformer: messages: Generator[ToolInvokeMessage, None, None], user_id: str, tenant_id: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Transform tool message and handle file download @@ -165,5 +164,5 @@ class ToolFileMessageTransformer: yield message @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: + def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 251d914800..526f5c8b9a 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models. """ import json -from typing import Optional, cast +from typing import cast from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult @@ -51,7 +51,7 @@ class ModelInvocationUtils: if not schema: raise InvokeModelError("No model schema found") - max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index cae21633fe..2e306db6c7 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -2,7 +2,6 @@ import re from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError -from typing import Optional from flask import request from requests import get @@ -198,9 +197,9 @@ class ApiBasedToolSchemaParser: return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} - typ: Optional[str] = None + typ: str | None = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py index f3c946b95f..6b7007842d 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -2,7 +2,7 @@ import base64 import hashlib import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from Crypto.Cipher import AES from Crypto.Random import get_random_bytes @@ -28,7 +28,7 @@ class SystemOAuthEncrypter: using AES-CBC mode with a key derived from the application's SECRET_KEY. """ - def __init__(self, secret_key: Optional[str] = None): + def __init__(self, secret_key: str | None = None): """ Initialize the OAuth encrypter. @@ -130,7 +130,7 @@ class SystemOAuthEncrypter: # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: +def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: """ Create an OAuth encrypter instance. @@ -144,7 +144,7 @@ def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAu # Global encrypter instance (for backward compatibility) -_oauth_encrypter: Optional[SystemOAuthEncrypter] = None +_oauth_encrypter: SystemOAuthEncrypter | None = None def get_system_oauth_encrypter() -> SystemOAuthEncrypter: diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index d8403c2e15..52c16c34a0 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -2,7 +2,7 @@ import mimetypes import re from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, cast +from typing import Any, cast from urllib.parse import unquote import chardet @@ -27,7 +27,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str: return text[cursor : cursor + max_length] -def get_url(url: str, user_agent: Optional[str] = None) -> str: +def get_url(url: str, user_agent: str | None = None) -> str: """Fetch URL and return the contents as a string.""" headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 8a0a91a50c..e9b5dab7d3 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -1,4 +1,5 @@ import logging +from functools import lru_cache from pathlib import Path from typing import Any @@ -8,28 +9,25 @@ from yaml import YAMLError logger = logging.getLogger(__name__) -def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}): - """ - Safe loading a YAML file - :param file_path: the path of the YAML file - :param ignore_error: - if True, return default_value if error occurs and the error will be logged in debug level - if False, raise error if error occurs - :param default_value: the value returned when errors ignored - :return: an object of the YAML content - """ +def _load_yaml_file(*, file_path: str): if not file_path or not Path(file_path).exists(): - if ignore_error: - return default_value - else: - raise FileNotFoundError(f"File not found: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) - return yaml_content or default_value + return yaml_content except Exception as e: - if ignore_error: - return default_value - else: - raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e + + +@lru_cache(maxsize=128) +def load_yaml_file_cached(file_path: str) -> Any: + """ + Cached version of load_yaml_file for static configuration files. + Only use for files that don't change during runtime (e.g., position files) + + :param file_path: the path of the YAML file + :return: an object of the YAML content + """ + return _load_yaml_file(file_path=file_path) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 18e6993b38..4d9c8895fc 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,5 +1,4 @@ from collections.abc import Mapping -from typing import Optional from pydantic import Field @@ -207,7 +206,7 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore + def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore """ get tool by name diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index c9d62388f2..6a1ac51528 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional +from typing import Any from sqlalchemy import select @@ -39,7 +39,7 @@ class WorkflowTool(Tool): entity: ToolEntity, runtime: ToolRuntime, label: str = "Workflow", - thread_pool_id: Optional[str] = None, + thread_pool_id: str | None = None, ): self.workflow_app_id = workflow_app_id self.workflow_as_tool_id = workflow_as_tool_id @@ -63,9 +63,9 @@ class WorkflowTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke the tool diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index ec62be605f..6fce5a83b9 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -37,7 +35,7 @@ _TEXT_COLOR_MAPPING = { class WorkflowLoggingCallback(WorkflowCallback): def __init__(self): - self.current_node_id: Optional[str] = None + self.current_node_id: str | None = None def on_event(self, event: GraphEngineEvent): if isinstance(event, GraphRunStartedEvent): @@ -250,7 +248,7 @@ class WorkflowLoggingCallback(WorkflowCallback): ) self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") - def print_text(self, text: str, color: Optional[str] = None, end: str = "\n"): + def print_text(self, text: str, color: str | None = None, end: str = "\n"): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(f"{text_to_print}", end=end) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 687ec8e47c..d672136d97 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -14,16 +14,16 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[Mapping[str, Any]] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata - llm_usage: Optional[LLMUsage] = None # llm usage + inputs: Mapping[str, Any] | None = None # node inputs + process_data: Mapping[str, Any] | None = None # process data + outputs: Mapping[str, Any] | None = None # node outputs + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # node metadata + llm_usage: LLMUsage | None = None # llm usage - edge_source_handle: Optional[str] = None # source handle id of node with multiple branches + edge_source_handle: str | None = None # source handle id of node with multiple branches - error: Optional[str] = None # error message if status is failed - error_type: Optional[str] = None # error type if status is failed + error: str | None = None # error message if status is failed + error_type: str | None = None # error type if status is failed # single step node run retry retry_index: int = 0 diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index f00dc11aa6..2e86605419 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -8,7 +8,7 @@ implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -45,7 +45,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any] = Field(...) inputs: Mapping[str, Any] = Field(...) - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING error_message: str = Field(default="") @@ -54,7 +54,7 @@ class WorkflowExecution(BaseModel): exceptions_count: int = Field(default=0) started_at: datetime = Field(...) - finished_at: Optional[datetime] = None + finished_at: datetime | None = None @property def elapsed_time(self) -> float: diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index ff72d7cbf3..e00099cda8 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -9,7 +9,7 @@ and don't contain implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -77,41 +77,41 @@ class WorkflowNodeExecution(BaseModel): # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: Optional[str] = None + node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow index: int # Sequence number for ordering in trace visualization - predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + predecessor_node_id: str | None = None # ID of the node that executed before this one node_id: str # ID of the node being executed node_type: NodeType # Type of node (e.g., start, llm, knowledge) title: str # Display title of the node # Execution data - inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node - process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data - outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + inputs: Mapping[str, Any] | None = None # Input variables used by this node + process_data: Mapping[str, Any] | None = None # Intermediate processing data + outputs: Mapping[str, Any] | None = None # Output variables produced by this node # Execution state status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: Optional[str] = None # Error message if execution failed + error: str | None = None # Error message if execution failed elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds # Additional metadata - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) # Timing information created_at: datetime # When execution started - finished_at: Optional[datetime] = None # When execution completed + finished_at: datetime | None = None # When execution completed def update_from_mapping( self, - inputs: Optional[Mapping[str, Any]] = None, - process_data: Optional[Mapping[str, Any]] = None, - outputs: Optional[Mapping[str, Any]] = None, - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, + inputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, ): """ Update the model from mappings. diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 6e72f8b152..c2865cdb02 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -29,7 +29,7 @@ class GraphRunStartedEvent(BaseGraphEvent): class GraphRunSucceededEvent(BaseGraphEvent): - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None """outputs""" @@ -40,7 +40,7 @@ class GraphRunFailedEvent(BaseGraphEvent): class GraphRunPartialSucceededEvent(BaseGraphEvent): exceptions_count: int = Field(..., description="exception count") - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None ########################################### @@ -54,33 +54,33 @@ class BaseNodeEvent(GraphEngineEvent): node_type: NodeType = Field(..., description="node type") node_data: BaseNodeData = Field(..., description="node data") route_node_state: RouteNodeState = Field(..., description="route node state") - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" # The version of the node, or "1" if not specified. node_version: str = "1" class NodeRunStartedEvent(BaseNodeEvent): - predecessor_node_id: Optional[str] = None + predecessor_node_id: str | None = None """predecessor node id""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration node parallel mode run id""" - agent_strategy: Optional[AgentNodeStrategyInit] = None + agent_strategy: AgentNodeStrategyInit | None = None class NodeRunStreamChunkEvent(BaseNodeEvent): chunk_content: str = Field(..., description="chunk content") - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None """from variable selector""" @@ -125,13 +125,13 @@ class BaseParallelBranchEvent(GraphEngineEvent): """parallel id""" parallel_start_node_id: str = Field(..., description="parallel start node id") """parallel start node id""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -157,45 +157,45 @@ class BaseIterationEvent(GraphEngineEvent): iteration_node_id: str = Field(..., description="iteration node id") iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") iteration_node_data: BaseNodeData = Field(..., description="node data") - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" class IterationRunStartedEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class IterationRunNextEvent(BaseIterationEvent): index: int = Field(..., description="index") - pre_iteration_output: Optional[Any] = None - duration: Optional[float] = None + pre_iteration_output: Any | None = None + duration: float | None = None class IterationRunSucceededEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - iteration_duration_map: Optional[dict[str, float]] = None + iteration_duration_map: dict[str, float] | None = None class IterationRunFailedEvent(BaseIterationEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") @@ -210,45 +210,45 @@ class BaseLoopEvent(GraphEngineEvent): loop_node_id: str = Field(..., description="loop node id") loop_node_type: NodeType = Field(..., description="node type, loop or loop") loop_node_data: BaseNodeData = Field(..., description="node data") - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """loop run in parallel mode run id""" class LoopRunStartedEvent(BaseLoopEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class LoopRunNextEvent(BaseLoopEvent): index: int = Field(..., description="index") - pre_loop_output: Optional[Any] = None - duration: Optional[float] = None + pre_loop_output: Any | None = None + duration: float | None = None class LoopRunSucceededEvent(BaseLoopEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 - loop_duration_map: Optional[dict[str, float]] = None + loop_duration_map: dict[str, float] | None = None class LoopRunFailedEvent(BaseLoopEvent): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") @@ -270,7 +270,7 @@ class AgentLogEvent(BaseAgentEvent): error: str | None = Field(..., description="error") status: str = Field(..., description="status") data: Mapping[str, Any] = Field(..., description="data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata") + metadata: Mapping[str, Any] | None = Field(default=None, description="metadata") node_id: str = Field(..., description="agent node id") diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index d8d1825d94..bb4a7e1e81 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,7 +1,7 @@ import uuid from collections import defaultdict from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel, Field @@ -17,18 +17,18 @@ from core.workflow.nodes.end.entities import EndStreamParam class GraphEdge(BaseModel): source_node_id: str = Field(..., description="source node id") target_node_id: str = Field(..., description="target node id") - run_condition: Optional[RunCondition] = None + run_condition: RunCondition | None = None """run condition""" class GraphParallel(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") start_from_node_id: str = Field(..., description="start from node id") - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id""" - end_to_node_id: Optional[str] = None + end_to_node_id: str | None = None """end to node id""" @@ -54,7 +54,7 @@ class Graph(BaseModel): end_stream_param: EndStreamParam = Field(..., description="end stream param") @classmethod - def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": + def init(cls, graph_config: Mapping[str, Any], root_node_id: str | None = None) -> "Graph": """ Init graph @@ -253,7 +253,7 @@ class Graph(BaseModel): start_node_id: str, parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str], - parent_parallel: Optional[GraphParallel] = None, + parent_parallel: GraphParallel | None = None, ): """ Recursively add parallel ids @@ -422,9 +422,9 @@ class Graph(BaseModel): cls, parallel_mapping: dict[str, GraphParallel], graph_edge: GraphEdge, - parallel: Optional[GraphParallel] = None, - parent_parallel: Optional[GraphParallel] = None, - ) -> Optional[GraphParallel]: + parallel: GraphParallel | None = None, + parent_parallel: GraphParallel | None = None, + ) -> GraphParallel | None: """ Get current parallel """ diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py index eedce8842b..7b9a379215 100644 --- a/api/core/workflow/graph_engine/entities/run_condition.py +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -1,5 +1,5 @@ import hashlib -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -10,10 +10,10 @@ class RunCondition(BaseModel): type: Literal["branch_identify", "condition"] """condition type""" - branch_identify: Optional[str] = None + branch_identify: str | None = None """branch identify like: sourceHandle, required when type is branch_identify""" - conditions: Optional[list[Condition]] = None + conditions: list[Condition] | None = None """conditions to run the node, required when type is condition""" @property diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 54440df725..c6b8a0b334 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,7 +1,6 @@ import uuid from datetime import datetime -from enum import Enum -from typing import Optional +from enum import StrEnum, auto from pydantic import BaseModel, Field @@ -11,12 +10,12 @@ from libs.datetime_utils import naive_utc_now class RouteNodeState(BaseModel): - class Status(Enum): - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - PAUSED = "paused" - EXCEPTION = "exception" + class Status(StrEnum): + RUNNING = auto() + SUCCESS = auto() + FAILED = auto() + PAUSED = auto() + EXCEPTION = auto() id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" @@ -24,7 +23,7 @@ class RouteNodeState(BaseModel): node_id: str """node id""" - node_run_result: Optional[NodeRunResult] = None + node_run_result: NodeRunResult | None = None """node run result""" status: Status = Status.RUNNING @@ -33,16 +32,16 @@ class RouteNodeState(BaseModel): start_at: datetime """start time""" - paused_at: Optional[datetime] = None + paused_at: datetime | None = None """paused time""" - finished_at: Optional[datetime] = None + finished_at: datetime | None = None """finished time""" - failed_reason: Optional[str] = None + failed_reason: str | None = None """failed reason""" - paused_by: Optional[str] = None + paused_by: str | None = None """paused by""" index: int = 1 diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 9b0b187a7c..bdb8070add 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -6,7 +6,7 @@ import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from typing import Any, Optional, cast +from typing import Any, cast from flask import Flask, current_app @@ -103,7 +103,7 @@ class GraphEngine: graph_runtime_state: GraphRuntimeState, max_execution_steps: int, max_execution_time: int, - thread_pool_id: Optional[str] = None, + thread_pool_id: str | None = None, ): thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_workers = 10 @@ -223,9 +223,9 @@ class GraphEngine: def _run( self, start_node_id: str, - in_parallel_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + in_parallel_id: str | None = None, + parent_parallel_id: str | None = None, + parent_parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None @@ -233,7 +233,7 @@ class GraphEngine: parallel_start_node_id = start_node_id next_node_id = start_node_id - previous_route_node_state: Optional[RouteNodeState] = None + previous_route_node_state: RouteNodeState | None = None while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: @@ -444,8 +444,8 @@ class GraphEngine: def _run_parallel_branches( self, edge_mappings: list[GraphEdge], - in_parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, + in_parallel_id: str | None = None, + parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes @@ -534,8 +534,8 @@ class GraphEngine: q: queue.Queue, parallel_id: str, parallel_start_node_id: str, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + parent_parallel_id: str | None = None, + parent_parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ): """ @@ -600,10 +600,10 @@ class GraphEngine: self, node: BaseNode, route_node_state: RouteNodeState, - parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + parallel_id: str | None = None, + parallel_start_node_id: str | None = None, + parent_parallel_id: str | None = None, + parent_parallel_start_node_id: str | None = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 2e5912652c..c075aa3e64 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from packaging.version import Version from pydantic import ValidationError @@ -69,7 +69,7 @@ class AgentNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AgentNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -78,7 +78,7 @@ class AgentNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -320,7 +320,7 @@ class AgentNode(BaseNode): memory = self._fetch_memory(model_instance) if memory: prompt_messages = memory.get_history_prompt_messages( - message_limit=node_data.memory.window.size if node_data.memory.window.size else None + message_limit=node_data.memory.window.size or None ) history_prompt_messages = [ prompt_message.model_dump(mode="json") for prompt_message in prompt_messages @@ -401,7 +401,7 @@ class AgentNode(BaseNode): icon = None return icon - def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: # get conversation id conversation_id_variable = self.graph_runtime_state.variable_pool.get( ["sys", SystemVariableKey.CONVERSATION_ID.value] diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 11b11068e7..ce6eb33ecc 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,4 +1,4 @@ -from enum import Enum, StrEnum +from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel @@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData): agent_parameters: dict[str, AgentInput] -class ParamsAutoGenerated(Enum): - CLOSE = 0 - OPEN = 1 +class ParamsAutoGenerated(IntEnum): + CLOSE = auto() + OPEN = auto() class AgentOldVersionModelFeatures(StrEnum): @@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index d5955bdd7d..944f5f0b20 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -1,6 +1,3 @@ -from typing import Optional - - class AgentNodeError(Exception): """Base exception for all agent node errors.""" @@ -12,7 +9,7 @@ class AgentNodeError(Exception): class AgentStrategyError(AgentNodeError): """Exception raised when there's an error with the agent strategy.""" - def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None): self.strategy_name = strategy_name self.provider_name = provider_name super().__init__(message) @@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError): class AgentStrategyNotFoundError(AgentStrategyError): """Exception raised when the specified agent strategy is not found.""" - def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + def __init__(self, strategy_name: str, provider_name: str | None = None): super().__init__( f"Agent strategy '{strategy_name}' not found" + (f" for provider '{provider_name}'" if provider_name else ""), @@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError): class AgentInvocationError(AgentNodeError): """Exception raised when there's an error invoking the agent.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError): class AgentParameterError(AgentNodeError): """Exception raised when there's an error with agent parameters.""" - def __init__(self, message: str, parameter_name: Optional[str] = None): + def __init__(self, message: str, parameter_name: str | None = None): self.parameter_name = parameter_name super().__init__(message) @@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError): class AgentVariableError(AgentNodeError): """Exception raised when there's an error with variables in the agent node.""" - def __init__(self, message: str, variable_name: Optional[str] = None): + def __init__(self, message: str, variable_name: str | None = None): self.variable_name = variable_name super().__init__(message) @@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError): class ToolFileError(AgentNodeError): """Exception raised when there's an error with a tool file.""" - def __init__(self, message: str, file_id: Optional[str] = None): + def __init__(self, message: str, file_id: str | None = None): self.file_id = file_id super().__init__(message) @@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError): class AgentMessageTransformError(AgentNodeError): """Exception raised when there's an error transforming agent messages.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError): class AgentModelError(AgentNodeError): """Exception raised when there's an error with the model used by the agent.""" - def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + def __init__(self, message: str, model_name: str | None = None, provider: str | None = None): self.model_name = model_name self.provider = provider super().__init__(message) @@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError): class AgentMemoryError(AgentNodeError): """Exception raised when there's an error with the agent's memory.""" - def __init__(self, message: str, conversation_id: Optional[str] = None): + def __init__(self, message: str, conversation_id: str | None = None): self.conversation_id = conversation_id super().__init__(message) @@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError): def __init__( self, message: str, - variable_name: Optional[str] = None, - expected_type: Optional[str] = None, - actual_type: Optional[str] = None, + variable_name: str | None = None, + expected_type: str | None = None, + actual_type: str | None = None, ): self.variable_name = variable_name self.expected_type = expected_type diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 116250c5ca..184f109127 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult @@ -25,7 +25,7 @@ class AnswerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AnswerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -34,7 +34,7 @@ class AnswerNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 9e8e1787e5..00eb28b882 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -1,7 +1,6 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator -from typing import Optional from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent @@ -72,7 +71,7 @@ class StreamProcessor(ABC): for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) - def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: + def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: str | None = None) -> list[str]: if node_id not in self.rest_node_ids: self.rest_node_ids.append(node_id) node_ids = [] diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..850ff14880 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum, auto from pydantic import BaseModel, Field @@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel): Generate Route Chunk. """ - class ChunkType(Enum): - VAR = "var" - TEXT = "text" + class ChunkType(StrEnum): + VAR = auto() + TEXT = auto() type: ChunkType = Field(..., description="generate route chunk type") diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 90e45e9d25..c1dac5a1da 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,7 +1,7 @@ import json from abc import ABC from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, model_validator @@ -121,10 +121,10 @@ class RetryConfig(BaseModel): class BaseNodeData(ABC, BaseModel): title: str - desc: Optional[str] = None + desc: str | None = None version: str = "1" - error_strategy: Optional[ErrorStrategy] = None - default_value: Optional[list[DefaultValue]] = None + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None retry_config: RetryConfig = RetryConfig() @property @@ -135,7 +135,7 @@ class BaseNodeData(ABC, BaseModel): class BaseIterationNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseIterationState(BaseModel): @@ -150,7 +150,7 @@ class BaseIterationState(BaseModel): class BaseLoopNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseLoopState(BaseModel): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 3aee9b2cc2..0fe8aa5908 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -26,8 +26,8 @@ class BaseNode: graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, ): self.id = id self.tenant_id = graph_init_params.tenant_id @@ -141,7 +141,7 @@ class BaseNode: return {} @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return {} @property @@ -170,7 +170,7 @@ class BaseNode: # to BaseNodeData properties in a type-safe way @abstractmethod - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" ... @@ -185,7 +185,7 @@ class BaseNode: ... @abstractmethod - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: """Get the node description.""" ... @@ -201,7 +201,7 @@ class BaseNode: # Public interface properties that delegate to abstract methods @property - def error_strategy(self) -> Optional[ErrorStrategy]: + def error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" return self._get_error_strategy() @@ -216,7 +216,7 @@ class BaseNode: return self._get_title() @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """Get the node description.""" return self._get_description() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index d32d868651..d5cf242182 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, Optional +from typing import Any from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -31,7 +31,7 @@ class CodeNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = CodeNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -40,7 +40,7 @@ class CodeNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -50,7 +50,7 @@ class CodeNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. @@ -161,7 +161,7 @@ class CodeNode(BaseNode): def _transform_result( self, result: Mapping[str, Any], - output_schema: Optional[dict[str, CodeNodeData.Output]], + output_schema: dict[str, CodeNodeData.Output] | None, prefix: str = "", depth: int = 1, ): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 9d380c6fb6..ab23e0ae83 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: Optional[dict[str, "CodeNodeData.Output"]] = None + children: dict[str, "CodeNodeData.Output"] | None = None class Dependency(BaseModel): name: str @@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None + dependencies: list[Dependency] | None = None diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 7848bab446..b488fec84a 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any import chardet import docx @@ -50,7 +50,7 @@ class DocumentExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = DocumentExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -59,7 +59,7 @@ class DocumentExtractorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 0aff039e92..b49fdc141f 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -17,7 +17,7 @@ class EndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = EndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -26,7 +26,7 @@ class EndNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 8d7ba25d47..5a7db6e0e6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,7 +1,7 @@ import mimetypes from collections.abc import Sequence from email.message import Message -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel): class HttpRequestNodeAuthorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[HttpRequestNodeAuthorizationConfig] = None + config: HttpRequestNodeAuthorizationConfig | None = None @field_validator("config", mode="before") @classmethod @@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData): authorization: HttpRequestNodeAuthorization headers: str params: str - body: Optional[HttpRequestNodeBody] = None - timeout: Optional[HttpRequestNodeTimeout] = None - ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + body: HttpRequestNodeBody | None = None + timeout: HttpRequestNodeTimeout | None = None + ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY class Response: @@ -183,7 +183,7 @@ class Response: return f"{(self.size / 1024 / 1024):.2f} MB" @property - def parsed_content_disposition(self) -> Optional[Message]: + def parsed_content_disposition(self) -> Message | None: content_disposition = self.headers.get("content-disposition", "") if content_disposition: msg = Message() diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index bb3c453d99..837cf883c8 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from configs import dify_config from core.file import File, FileTransferMethod @@ -41,7 +41,7 @@ class HttpRequestNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = HttpRequestNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -50,7 +50,7 @@ class HttpRequestNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict[str, Any]] = None): + def get_default_config(cls, filters: dict[str, Any] | None = None): return { "type": "http-request", "config": { diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 67d6d6a886..b22bd6f508 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData): logical_operator: Literal["and", "or"] conditions: list[Condition] - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) - cases: Optional[list[Case]] = None + cases: list[Case] | None = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 82dba59cbe..857b1c6f44 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import deprecated @@ -22,7 +22,7 @@ class IfElseNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IfElseNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -31,7 +31,7 @@ class IfElseNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 7a489dd725..9608edb06e 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ - parent_loop_id: Optional[str] = None # redundant field, not used currently + parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector is_parallel: bool = False # open the parallel mode or not @@ -39,7 +39,7 @@ class IterationState(BaseIterationState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any | None = None class MetaData(BaseIterationState.MetaData): """ @@ -48,7 +48,7 @@ class IterationState(BaseIterationState): iterator_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any | None: """ Get last output. """ @@ -56,7 +56,7 @@ class IterationState(BaseIterationState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any | None: """ Get current output. """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 52eb7fdd75..2cf59bc2fb 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -6,7 +6,7 @@ from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait from datetime import datetime from queue import Empty, Queue -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from flask import Flask, current_app @@ -70,7 +70,7 @@ class IterationNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -79,7 +79,7 @@ class IterationNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -89,7 +89,7 @@ class IterationNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "type": "iteration", "config": { @@ -424,7 +424,7 @@ class IterationNode(BaseNode): graph_engine: "GraphEngine", iteration_graph: Graph, iter_run_map: dict[str, float], - parallel_mode_run_id: Optional[str] = None, + parallel_mode_run_id: str | None = None, ) -> Generator[NodeEvent | InNodeEvent, None, None]: """ run single iteration diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 8c4794cf37..1a6c9fa908 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -21,7 +21,7 @@ class IterationStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class IterationStartNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index b71271abeb..8aa6a5016f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel): """ top_k: int - score_threshold: Optional[float] = None + score_threshold: float | None = None reranking_mode: str = "reranking_model" reranking_enable: bool = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class SingleRetrievalConfig(BaseModel): @@ -91,7 +91,7 @@ SupportedComparisonOperator = Literal[ class Condition(BaseModel): """ - Conditon detail + Condition detail """ name: str @@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel): Metadata Filtering Condition. """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) class KnowledgeRetrievalNodeData(BaseNodeData): @@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData): query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] - multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None - single_retrieval_config: Optional[SingleRetrievalConfig] = None - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + multiple_retrieval_config: MultipleRetrievalConfig | None = None + single_retrieval_config: SingleRetrievalConfig | None = None + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" + metadata_model_config: ModelConfig | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d357fea7dd..99e1ba6d28 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import Float, and_, func, or_, select, text from sqlalchemy import cast as sqlalchemy_cast @@ -101,8 +101,8 @@ class KnowledgeRetrievalNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = KnowledgeRetrievalNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -137,7 +137,7 @@ class KnowledgeRetrievalNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode): ) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: + if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -419,7 +419,7 @@ class KnowledgeRetrievalNode(BaseNode): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -576,7 +576,7 @@ class KnowledgeRetrievalNode(BaseNode): return automatic_metadata_filters def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list ): if value is None and condition not in ("empty", "not empty"): return diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index cf46870254..8a6d3d0c5a 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, Optional, TypeAlias, TypeVar +from typing import Any, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -44,7 +44,7 @@ class ListOperatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ListOperatorNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -53,7 +53,7 @@ class ListOperatorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -178,7 +178,7 @@ class ListOperatorNode(BaseNode): result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) else: - raise AssertionError("this statment should be unreachable.") + raise AssertionError("this statement should be unreachable.") return variable def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: @@ -191,7 +191,7 @@ class ListOperatorNode(BaseNode): ) variable = variable.model_copy(update={"value": result}) else: - raise AssertionError("this statement should be unreachable") + raise AssertionError("this statement should be unreachable.") return variable diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 222914351e..3dfb1ce28e 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -18,7 +18,7 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): enabled: bool - variable_selector: Optional[list[str]] = None + variable_selector: list[str] | None = None class VisionConfigOptions(BaseModel): @@ -51,18 +51,18 @@ class PromptConfig(BaseModel): class LLMNodeChatModelMessage(ChatModelMessage): text: str = "" - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: Optional[MemoryConfig] = None + memory: MemoryConfig | None = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: Mapping[str, Any] | None = None diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index fae127ab76..ce6bb441ab 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from sqlalchemy import select, update from sqlalchemy.orm import Session @@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc def fetch_memory( - variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance -) -> Optional[TokenBufferMemory]: + variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance +) -> TokenBufferMemory | None: if not node_data_memory: return None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index fdcdac1ec2..9ae4f275fb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -4,7 +4,7 @@ import json import logging import re from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -116,8 +116,8 @@ class LLMNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -143,7 +143,7 @@ class LLMNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LLMNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -152,7 +152,7 @@ class LLMNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -166,7 +166,7 @@ class LLMNode(BaseNode): return "1" def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: - node_inputs: Optional[dict[str, Any]] = None + node_inputs: dict[str, Any] | None = None process_data = None result_text = "" usage = LLMUsage.empty_usage() @@ -353,10 +353,10 @@ class LLMNode(BaseNode): node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, user_id: str, structured_output_enabled: bool, - structured_output: Optional[Mapping[str, Any]] = None, + structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, file_outputs: list["File"], node_id: str, @@ -708,7 +708,7 @@ class LLMNode(BaseNode): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): @@ -951,7 +951,7 @@ class LLMNode(BaseNode): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "type": "llm", "config": { @@ -979,7 +979,7 @@ class LLMNode(BaseNode): def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: Optional[str], + context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, @@ -1174,7 +1174,7 @@ class LLMNode(BaseNode): def _combine_message_content_with_role( - *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole ): match role: case PromptMessageRole.USER: @@ -1280,7 +1280,7 @@ def _handle_memory_completion_mode( def _handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, - context: Optional[str], + context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ) -> Sequence[PromptMessage]: diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3ed4d21ba5..c875b4202e 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field @@ -35,7 +35,7 @@ class LoopVariableData(BaseModel): label: str var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] - value: Optional[Any | list[str]] = None + value: Any | list[str] | None = None class LoopNodeData(BaseLoopNodeData): @@ -46,8 +46,8 @@ class LoopNodeData(BaseLoopNodeData): loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] - loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData]) - outputs: Optional[Mapping[str, Any]] = None + loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) + outputs: Mapping[str, Any] | None = None class LoopStartNodeData(BaseNodeData): @@ -72,7 +72,7 @@ class LoopState(BaseLoopState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any | None = None class MetaData(BaseLoopState.MetaData): """ @@ -81,7 +81,7 @@ class LoopState(BaseLoopState): loop_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any | None: """ Get last output. """ @@ -89,7 +89,7 @@ class LoopState(BaseLoopState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any | None: """ Get current output. """ diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 892ae88b04..e2940ae004 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -21,7 +21,7 @@ class LoopEndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopEndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class LoopEndNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 2fe3fb5567..753963dc90 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from configs import dify_config from core.variables import ( @@ -57,7 +57,7 @@ class LoopNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -66,7 +66,7 @@ class LoopNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index f5a20fc009..07e98a494f 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -21,7 +21,7 @@ class LoopStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class LoopStartNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 2739140224..2dc0aabe3c 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -50,7 +50,7 @@ class ParameterConfig(BaseModel): name: str type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: Optional[list[str]] = None + options: list[str] | None = None description: str required: bool @@ -88,8 +88,8 @@ class ParameterExtractorNodeData(BaseNodeData): model: ModelConfig query: list[str] parameters: list[ParameterConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None reasoning_mode: Literal["function_call", "prompt"] vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 1e1c10a11a..51d9a2d2e9 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -3,7 +3,7 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File @@ -98,7 +98,7 @@ class ParameterExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ParameterExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -107,7 +107,7 @@ class ParameterExtractorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -116,11 +116,11 @@ class ParameterExtractorNode(BaseNode): def get_base_node_data(self) -> BaseNodeData: return self._node_data - _model_instance: Optional[ModelInstance] = None - _model_config: Optional[ModelConfigWithCredentialsEntity] = None + _model_instance: ModelInstance | None = None + _model_config: ModelConfigWithCredentialsEntity | None = None @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): return { "model": { "prompt_templates": { @@ -295,7 +295,7 @@ class ParameterExtractorNode(BaseNode): prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], stop: list[str], - ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=node_data_model.completion_params, @@ -330,9 +330,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. @@ -412,9 +412,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate prompt engineering prompt. @@ -450,9 +450,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate completion prompt. @@ -484,9 +484,9 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate chat prompt. @@ -657,7 +657,7 @@ class ParameterExtractorNode(BaseNode): return transformed_result - def _extract_complete_json_response(self, result: str) -> Optional[dict]: + def _extract_complete_json_response(self, result: str) -> dict | None: """ Extract complete json response. """ @@ -672,7 +672,7 @@ class ParameterExtractorNode(BaseNode): logger.info("extra error: %s", result) return None - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: + def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: """ Extract json from tool call. """ @@ -711,7 +711,7 @@ class ParameterExtractorNode(BaseNode): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ) -> list[ChatModelMessage]: model_mode = ModelMode(node_data.model.mode) @@ -738,7 +738,7 @@ class ParameterExtractorNode(BaseNode): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ): model_mode = ModelMode(node_data.model.mode) @@ -774,7 +774,7 @@ class ParameterExtractorNode(BaseNode): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - context: Optional[str], + context: str | None, ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 6248df0edf..edde30708a 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData): query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 07fb658f24..b15193ecde 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -59,8 +59,8 @@ class QuestionClassifierNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -86,7 +86,7 @@ class QuestionClassifierNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = QuestionClassifierNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -95,7 +95,7 @@ class QuestionClassifierNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. @@ -288,7 +288,7 @@ class QuestionClassifierNode(BaseNode): node_data: QuestionClassifierNodeData, query: str, model_config: ModelConfigWithCredentialsEntity, - context: Optional[str], + context: str | None, ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) @@ -331,7 +331,7 @@ class QuestionClassifierNode(BaseNode): self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ): model_mode = ModelMode(node_data.model.mode) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 6052774e6c..5015d59ccc 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult @@ -18,7 +18,7 @@ class StartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = StartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class StartNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 5588463a36..761854045c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,6 +1,6 @@ import os from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult @@ -21,7 +21,7 @@ class TemplateTransformNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = TemplateTransformNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -30,7 +30,7 @@ class TemplateTransformNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict | None = None): """ Get default config of node. :param filters: filter by node config parameters. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c4caf5d83b..53632f43c6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session @@ -439,7 +439,7 @@ class ToolNode(BaseNode): return result - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -448,7 +448,7 @@ class ToolNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index f4577d7573..13dbc5dbe6 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from core.variables.types import SegmentType @@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData): type: str = "variable-assigner" output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None + advanced_settings: AdvancedSettings | None = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index cc5092d0a9..1c1817496f 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult @@ -18,7 +18,7 @@ class VariableAggregatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class VariableAggregatorNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 263c5a3893..8cf9e82d3b 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias from core.variables import SegmentType, Variable from core.variables.segments import BooleanSegment @@ -33,7 +33,7 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -42,7 +42,7 @@ class VariableAssignerNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -58,8 +58,8 @@ class VariableAssignerNode(BaseNode): graph_init_params: "GraphInitParams", graph: "Graph", graph_runtime_state: "GraphRuntimeState", - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None, + previous_node_id: str | None = None, + thread_pool_id: str | None = None, conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, ): super().__init__( diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index fdd155adf9..9915b842f7 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -61,7 +61,7 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -70,7 +70,7 @@ class VariableAssignerNode(BaseNode): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index f4668c05c5..8148934b0e 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Optional, Protocol +from typing import Literal, Protocol from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution @@ -10,7 +10,7 @@ class OrderConfig: """Configuration for ordering NodeExecution instances.""" order_by: list[str] - order_direction: Optional[Literal["asc", "desc"]] = None + order_direction: Literal["asc", "desc"] | None = None class WorkflowNodeExecutionRepository(Protocol): @@ -42,7 +42,7 @@ class WorkflowNodeExecutionRepository(Protocol): def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 4f259b64a2..0410b843b9 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, Union +from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -83,9 +83,9 @@ class WorkflowCycleManager: total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, - external_trace_id: Optional[str] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) @@ -110,9 +110,9 @@ class WorkflowCycleManager: total_steps: int, outputs: Mapping[str, Any] | None = None, exceptions_count: int = 0, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, - external_trace_id: Optional[str] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) @@ -138,10 +138,10 @@ class WorkflowCycleManager: total_steps: int, status: WorkflowExecutionStatus, error_message: str, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, exceptions_count: int = 0, - external_trace_id: Optional[str] = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) now = naive_utc_now() @@ -296,9 +296,9 @@ class WorkflowCycleManager: total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, - error_message: Optional[str] = None, + error_message: str | None = None, exceptions_count: int = 0, - finished_at: Optional[datetime] = None, + finished_at: datetime | None = None, ): """Update workflow execution with completion data.""" execution.status = status @@ -312,10 +312,10 @@ class WorkflowCycleManager: def _add_trace_task_if_needed( self, - trace_manager: Optional[TraceQueueManager], + trace_manager: TraceQueueManager | None, workflow_execution: WorkflowExecution, - conversation_id: Optional[str], - external_trace_id: Optional[str], + conversation_id: str | None, + external_trace_id: str | None, ): """Add trace task if trace manager is provided.""" if trace_manager: @@ -357,8 +357,8 @@ class WorkflowCycleManager: workflow_execution: WorkflowExecution, event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent], status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, - created_at: Optional[datetime] = None, + error: str | None = None, + created_at: datetime | None = None, ) -> WorkflowNodeExecution: """Create a node execution from an event.""" now = naive_utc_now() @@ -404,7 +404,7 @@ class WorkflowCycleManager: QueueNodeExceptionEvent, ], status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, + error: str | None = None, handle_special_values: bool = False, ): """Update node execution with completion data.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index b69a9971b5..ecad75b1ca 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional +from typing import Any from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError @@ -47,7 +47,7 @@ class WorkflowEntry: invoke_from: InvokeFrom, call_depth: int, variable_pool: VariablePool, - thread_pool_id: Optional[str] = None, + thread_pool_id: str | None = None, ): """ Init workflow entry @@ -311,7 +311,7 @@ class WorkflowEntry: raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod - def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: # NOTE(QuantumGhost): Avoid using this function in new code. # Keep values structured as long as possible and only convert to dict # immediately before serialization (e.g., JSON serialization) to maintain @@ -367,7 +367,7 @@ class WorkflowEntry: raise ValueError(f"Variable key {node_variable} not found in user inputs.") # environment variable already exist in variable pool, not from user inputs - if variable_pool.get(variable_selector): + if variable_pool.get(variable_selector) and variable_selector[0] == ENVIRONMENT_VARIABLE_NODE_ID: continue # fetch variable node id from variable selector diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 21fd0b9c5b..c318684b2f 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -1,7 +1,7 @@ import logging import time as time_module from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel from sqlalchemy import update @@ -33,7 +33,7 @@ def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str: @redis_fallback(default_return=None) -def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]: +def _get_last_update_timestamp(cache_key: str) -> datetime | None: """Get last update timestamp from Redis cache.""" timestamp_str = redis_client.get(cache_key) if timestamp_str: @@ -52,8 +52,8 @@ class _ProviderUpdateFilters(BaseModel): tenant_id: str provider_name: str - provider_type: Optional[str] = None - quota_type: Optional[str] = None + provider_type: str | None = None + quota_type: str | None = None class _ProviderUpdateAdditionalFilters(BaseModel): @@ -65,8 +65,8 @@ class _ProviderUpdateAdditionalFilters(BaseModel): class _ProviderUpdateValues(BaseModel): """Values to update in Provider records.""" - last_used: Optional[datetime] = None - quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression + last_used: datetime | None = None + quota_used: Any | None = None # Can be Provider.quota_used + int expression class _ProviderUpdateOperation(BaseModel): @@ -182,7 +182,7 @@ def handle(sender: Message, **kwargs): def _calculate_quota_usage( *, message: Message, system_configuration: SystemConfiguration, model_name: str -) -> Optional[int]: +) -> int | None: """Calculate quota usage based on message tokens and quota type.""" quota_unit = None for quota_configuration in system_configuration.quota_configurations: diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index fb5352ca8f..585539e2ce 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,6 +1,6 @@ import ssl from datetime import timedelta -from typing import Any, Optional +from typing import Any import pytz from celery import Celery, Task @@ -10,7 +10,7 @@ from configs import dify_config from dify_app import DifyApp -def _get_celery_ssl_options() -> Optional[dict[str, Any]]: +def _get_celery_ssl_options() -> dict[str, Any] | None: """Get SSL configuration for Celery broker/backend connections.""" # Use REDIS_USE_SSL for consistency with the main Redis client # Only apply SSL if we're using Redis as broker/backend @@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery: imports.append("schedule.queue_monitor_task") beat_schedule["datasets-queue-monitor"] = { "task": "schedule.queue_monitor_task.queue_monitor_task", - "schedule": timedelta( - minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 - ), + "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30), } if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 58ab023559..042bf8cc47 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from flask import Flask @@ -68,7 +67,7 @@ class Mail: case _: raise ValueError(f"Unsupported mail type {mail_type}") - def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): + def send(self, to: str, subject: str, html: str, from_: str | None = None): if not self._client: raise ValueError("Mail client is not initialized") diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 61b26b5b95..487917b2a7 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import redis from redis import RedisError @@ -246,7 +246,7 @@ def init_app(app: DifyApp): app.extensions["redis"] = redis_client -def redis_fallback(default_return: Optional[Any] = None): +def redis_fallback(default_return: Any | None = None): """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 7ec0889776..9053aece89 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,6 +1,5 @@ from collections.abc import Generator from datetime import timedelta -from typing import Optional from azure.identity import ChainedTokenCredential, DefaultAzureCredential from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas @@ -21,7 +20,7 @@ class AzureBlobStorage(BaseStorage): self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY - self.credential: Optional[ChainedTokenCredential] = None + self.credential: ChainedTokenCredential | None = None if self.account_key == "managedidentity": self.credential = DefaultAzureCredential() else: diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 33fa7d0a8d..2ffac9a92d 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -10,7 +10,6 @@ import tempfile from collections.abc import Generator from io import BytesIO from pathlib import Path -from typing import Optional import clickzetta # type: ignore[import] from pydantic import BaseModel, model_validator @@ -33,7 +32,7 @@ class ClickZettaVolumeConfig(BaseModel): vcluster: str = "default_ap" schema_name: str = "dify" volume_type: str = "table" # table|user|external - volume_name: Optional[str] = None # For external volumes + volume_name: str | None = None # For external volumes table_prefix: str = "dataset_" # Prefix for table volume names dify_prefix: str = "dify_km" # Directory prefix for User Volume permission_check: bool = True # Enable/disable permission checking @@ -154,7 +153,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.exception("Failed to initialize permission manager") raise - def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str: + def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str: """Get the appropriate volume path based on volume type.""" if self._config.volume_type == "user": # Add dify prefix for User Volume to organize files @@ -179,7 +178,7 @@ class ClickZettaVolumeStorage(BaseStorage): else: raise ValueError(f"Unsupported volume type: {self._config.volume_type}") - def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str: + def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str: """Get SQL prefix for volume operations.""" if self._config.volume_type == "user": return "USER VOLUME" diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index ef6b12fd59..6ab02ad8cc 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -7,21 +7,22 @@ Supports complete lifecycle management for knowledge base files. import json import logging +import operator from dataclasses import asdict, dataclass from datetime import datetime -from enum import Enum -from typing import Any, Optional +from enum import StrEnum, auto +from typing import Any logger = logging.getLogger(__name__) -class FileStatus(Enum): +class FileStatus(StrEnum): """File status enumeration""" - ACTIVE = "active" # Active status - ARCHIVED = "archived" # Archived - DELETED = "deleted" # Deleted (soft delete) - BACKUP = "backup" # Backup file + ACTIVE = auto() # Active status + ARCHIVED = auto() # Archived + DELETED = auto() # Deleted (soft delete) + BACKUP = auto() # Backup file @dataclass @@ -34,9 +35,9 @@ class FileMetadata: modified_at: datetime version: int | None status: FileStatus - checksum: Optional[str] = None - tags: Optional[dict[str, str]] = None - parent_version: Optional[int] = None + checksum: str | None = None + tags: dict[str, str] | None = None + parent_version: int | None = None def to_dict(self): """Convert to dictionary format""" @@ -59,7 +60,7 @@ class FileMetadata: class FileLifecycleManager: """File lifecycle manager""" - def __init__(self, storage, dataset_id: Optional[str] = None): + def __init__(self, storage, dataset_id: str | None = None): """Initialize lifecycle manager Args: @@ -74,9 +75,9 @@ class FileLifecycleManager: self._deleted_prefix = ".deleted/" # Get permission manager (if exists) - self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None) + self._permission_manager: Any | None = getattr(storage, "_permission_manager", None) - def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata: + def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata: """Save file and manage lifecycle Args: @@ -150,7 +151,7 @@ class FileLifecycleManager: logger.exception("Failed to save file with lifecycle") raise - def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: + def get_file_metadata(self, filename: str) -> FileMetadata | None: """Get file metadata Args: @@ -356,7 +357,7 @@ class FileLifecycleManager: # Cleanup old versions for each file for base_filename, versions in file_versions.items(): # Sort by version number - versions.sort(key=lambda x: x[0], reverse=True) + versions.sort(key=operator.itemgetter(0), reverse=True) # Keep the newest max_versions versions, delete the rest if len(versions) > max_versions: diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 243df92efe..eb1116638f 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -5,13 +5,12 @@ According to ClickZetta's permission model, different Volume types have differen """ import logging -from enum import Enum -from typing import Optional +from enum import StrEnum logger = logging.getLogger(__name__) -class VolumePermission(Enum): +class VolumePermission(StrEnum): """Volume permission type enumeration""" READ = "SELECT" # Corresponds to ClickZetta's SELECT permission @@ -24,7 +23,7 @@ class VolumePermission(Enum): class VolumePermissionManager: """Volume permission manager""" - def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): + def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None): """Initialize permission manager Args: @@ -63,7 +62,7 @@ class VolumePermissionManager: self._permission_cache: dict[str, set[str]] = {} self._current_username = None # Will get current username from connection - def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: + def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool: """Check if user has permission to perform specific operation Args: @@ -126,7 +125,7 @@ class VolumePermissionManager: logger.info("User Volume permission check failed, but permission checking is disabled in this version") return False - def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: + def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool: """Check Table Volume permission Table Volume permission rules: @@ -440,7 +439,7 @@ class VolumePermissionManager: self._permission_cache.clear() logger.debug("Permission cache cleared") - def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]: + def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]: """Get permission summary Args: @@ -582,7 +581,7 @@ class VolumePermissionManager: return any(pattern in file_path_lower for pattern in sensitive_patterns) - def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool: + def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool: """Validate operation permission Args: @@ -614,16 +613,14 @@ class VolumePermissionManager: class VolumePermissionError(Exception): """Volume permission error exception""" - def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): + def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None): self.operation = operation self.volume_type = volume_type self.dataset_id = dataset_id super().__init__(message) -def check_volume_permission( - permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None -): +def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None): """Permission check decorator function Args: diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index 3c039dff53..37ff1a438e 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -7,8 +7,8 @@ eliminates the need for repetitive language switching logic. """ from dataclasses import dataclass -from enum import Enum -from typing import Any, Optional, Protocol +from enum import StrEnum, auto +from typing import Any, Protocol from flask import render_template from pydantic import BaseModel, Field @@ -17,26 +17,30 @@ from extensions.ext_mail import mail from services.feature_service import BrandingModel, FeatureService -class EmailType(Enum): +class EmailType(StrEnum): """Enumeration of supported email types.""" - RESET_PASSWORD = "reset_password" - INVITE_MEMBER = "invite_member" - EMAIL_CODE_LOGIN = "email_code_login" - CHANGE_EMAIL_OLD = "change_email_old" - CHANGE_EMAIL_NEW = "change_email_new" - CHANGE_EMAIL_COMPLETED = "change_email_completed" - OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm" - OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify" - OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify" - ACCOUNT_DELETION_SUCCESS = "account_deletion_success" - ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification" - ENTERPRISE_CUSTOM = "enterprise_custom" - QUEUE_MONITOR_ALERT = "queue_monitor_alert" - DOCUMENT_CLEAN_NOTIFY = "document_clean_notify" + RESET_PASSWORD = auto() + RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto() + INVITE_MEMBER = auto() + EMAIL_CODE_LOGIN = auto() + CHANGE_EMAIL_OLD = auto() + CHANGE_EMAIL_NEW = auto() + CHANGE_EMAIL_COMPLETED = auto() + OWNER_TRANSFER_CONFIRM = auto() + OWNER_TRANSFER_OLD_NOTIFY = auto() + OWNER_TRANSFER_NEW_NOTIFY = auto() + ACCOUNT_DELETION_SUCCESS = auto() + ACCOUNT_DELETION_VERIFICATION = auto() + ENTERPRISE_CUSTOM = auto() + QUEUE_MONITOR_ALERT = auto() + DOCUMENT_CLEAN_NOTIFY = auto() + EMAIL_REGISTER = auto() + EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto() + RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto() -class EmailLanguage(Enum): +class EmailLanguage(StrEnum): """Supported email languages with fallback handling.""" EN_US = "en-US" @@ -167,7 +171,7 @@ class EmailI18nService: email_type: EmailType, language_code: str, to: str, - template_context: Optional[dict[str, Any]] = None, + template_context: dict[str, Any] | None = None, ): """ Send internationalized email with branding support. @@ -441,6 +445,54 @@ def create_default_email_config() -> EmailI18nConfig: branded_template_path="clean_document_job_mail_template_zh-CN.html", ), }, + EmailType.EMAIL_REGISTER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Register Your {application_title} Account", + template_path="register_email_template_en-US.html", + branded_template_path="without-brand/register_email_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="注册您的 {application_title} 账户", + template_path="register_email_template_zh-CN.html", + branded_template_path="without-brand/register_email_template_zh-CN.html", + ), + }, + EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: { + EmailLanguage.EN_US: EmailTemplate( + subject="Register Your {application_title} Account", + template_path="register_email_when_account_exist_template_en-US.html", + branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="注册您的 {application_title} 账户", + template_path="register_email_when_account_exist_template_zh-CN.html", + branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html", + ), + }, + EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: { + EmailLanguage.EN_US: EmailTemplate( + subject="Reset Your {application_title} Password", + template_path="reset_password_mail_when_account_not_exist_template_en-US.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="重置您的 {application_title} 密码", + template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html", + ), + }, + EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Reset Your {application_title} Password", + template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="重置您的 {application_title} 密码", + template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html", + branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html", + ), + }, } return EmailI18nConfig(templates=templates) @@ -463,7 +515,7 @@ def get_default_email_i18n_service() -> EmailI18nService: # Global instance -_email_i18n_service: Optional[EmailI18nService] = None +_email_i18n_service: EmailI18nService | None = None def get_email_i18n_service() -> EmailI18nService: diff --git a/api/libs/exception.py b/api/libs/exception.py index 5970269ecd..73379dfded 100644 --- a/api/libs/exception.py +++ b/api/libs/exception.py @@ -1,11 +1,9 @@ -from typing import Optional - from werkzeug.exceptions import HTTPException class BaseHTTPException(HTTPException): error_code: str = "unknown" - data: Optional[dict] = None + data: dict | None = None def __init__(self, description=None, response=None): super().__init__(description, response) diff --git a/api/libs/helper.py b/api/libs/helper.py index f3c46b4843..0551470f65 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw): if isinstance(obj, dict) and "app" in obj: obj = obj["app"] - if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value: + if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE: return file_helpers.get_signed_file_url(obj.icon) return None @@ -269,8 +269,8 @@ class TokenManager: cls, token_type: str, account: Optional["Account"] = None, - email: Optional[str] = None, - additional_data: Optional[dict] = None, + email: str | None = None, + additional_data: dict | None = None, ) -> str: if account is None and email is None: raise ValueError("Account or email must be provided") @@ -312,19 +312,19 @@ class TokenManager: redis_client.delete(token_key) @classmethod - def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: + def get_token_data(cls, token: str, token_type: str) -> dict[str, Any] | None: key = cls._get_token_key(token, token_type) token_data_json = redis_client.get(key) if token_data_json is None: logger.warning("%s token %s not found with key %s", token_type, token, key) return None - token_data: Optional[dict[str, Any]] = json.loads(token_data_json) + token_data: dict[str, Any] | None = json.loads(token_data_json) return token_data @classmethod - def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: + def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None: key = cls._get_account_token_key(account_id, token_type) - current_token: Optional[str] = redis_client.get(key) + current_token: str | None = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/oauth.py b/api/libs/oauth.py index df75b55019..35bd6c2c7c 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,6 +1,5 @@ import urllib.parse from dataclasses import dataclass -from typing import Optional import requests @@ -41,7 +40,7 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: Optional[str] = None): + def get_authorization_url(self, invite_token: str | None = None): params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, @@ -93,7 +92,7 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: Optional[str] = None): + def get_authorization_url(self, invite_token: str | None = None): params = { "client_id": self.client_id, "response_type": "code", diff --git a/api/libs/orjson.py b/api/libs/orjson.py index 2fc5ce8dd3..6e7c6b738d 100644 --- a/api/libs/orjson.py +++ b/api/libs/orjson.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import orjson @@ -6,6 +6,6 @@ import orjson def orjson_dumps( obj: Any, encoding: str = "utf-8", - option: Optional[int] = None, + option: int | None = None, ) -> str: return orjson.dumps(obj, option=option).decode(encoding) diff --git a/api/migrations/versions/2025_09_11_1537-cf7c38a32b2d_add_credential_status_for_provider_table.py b/api/migrations/versions/2025_09_11_1537-cf7c38a32b2d_add_credential_status_for_provider_table.py new file mode 100644 index 0000000000..17467e6495 --- /dev/null +++ b/api/migrations/versions/2025_09_11_1537-cf7c38a32b2d_add_credential_status_for_provider_table.py @@ -0,0 +1,33 @@ +"""Add credential status for provider table + +Revision ID: cf7c38a32b2d +Revises: c20211f18133 +Create Date: 2025-09-11 15:37:17.771298 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'cf7c38a32b2d' +down_revision = 'c20211f18133' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_status', sa.String(length=20), server_default=sa.text("'active'::character varying"), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('credential_status') + + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/models/account.py b/api/models/account.py index 4656b47e7a..8c1f990aa2 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -7,6 +7,7 @@ import sqlalchemy as sa from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor +from typing_extensions import deprecated from models.base import Base @@ -89,24 +90,24 @@ class Account(UserMixin, Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255)) - password: Mapped[Optional[str]] = mapped_column(String(255)) - password_salt: Mapped[Optional[str]] = mapped_column(String(255)) - avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - interface_language: Mapped[Optional[str]] = mapped_column(String(255)) - interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - timezone: Mapped[Optional[str]] = mapped_column(String(255)) - last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + password: Mapped[str | None] = mapped_column(String(255)) + password_salt: Mapped[str | None] = mapped_column(String(255)) + avatar: Mapped[str | None] = mapped_column(String(255), nullable=True) + interface_language: Mapped[str | None] = mapped_column(String(255)) + interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True) + timezone: Mapped[str | None] = mapped_column(String(255)) + last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True) last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) - initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): - self.role: Optional[TenantAccountRole] = None - self._current_tenant: Optional[Tenant] = None + self.role: TenantAccountRole | None = None + self._current_tenant: Tenant | None = None @property def is_password_set(self): @@ -187,7 +188,28 @@ class Account(UserMixin, Base): return TenantAccountRole.is_admin_role(self.role) @property + @deprecated("Use has_edit_permission instead.") def is_editor(self): + """Determines if the account has edit permissions in their current tenant (workspace). + + This property checks if the current role has editing privileges, which includes: + - `OWNER` + - `ADMIN` + - `EDITOR` + + Note: This checks for any role with editing permission, not just the 'EDITOR' role specifically. + """ + return self.has_edit_permission + + @property + def has_edit_permission(self): + """Determines if the account has editing permissions in their current tenant (workspace). + + This property checks if the current role has editing privileges, which includes: + - `OWNER` + - `ADMIN` + - `EDITOR` + """ return TenantAccountRole.is_editing_role(self.role) @property @@ -210,10 +232,10 @@ class Tenant(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255)) - encrypt_public_key: Mapped[Optional[str]] = mapped_column(sa.Text) + encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text) plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) - custom_config: Mapped[Optional[str]] = mapped_column(sa.Text) + custom_config: Mapped[str | None] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) @@ -249,7 +271,7 @@ class TenantAccountJoin(Base): account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) role: Mapped[str] = mapped_column(String(16), server_default="normal") - invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) + invited_by: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) @@ -283,10 +305,10 @@ class InvitationCode(Base): batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) - used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) - used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + used_at: Mapped[datetime | None] = mapped_column(DateTime) + used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID) + used_by_account_id: Mapped[str | None] = mapped_column(StringUUID) + deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/dataset.py b/api/models/dataset.py index 13087bf995..662cfeb0d2 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -10,7 +10,7 @@ import re import time from datetime import datetime from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select @@ -56,7 +56,7 @@ class Dataset(Base): provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying")) permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying")) data_source_type = mapped_column(String(255)) - indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) + indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -224,35 +224,35 @@ class Dataset(Base): doc_metadata.append( { "id": "built-in", - "name": BuiltInField.document_name.value, + "name": BuiltInField.document_name, "type": "string", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.uploader.value, + "name": BuiltInField.uploader, "type": "string", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.upload_date.value, + "name": BuiltInField.upload_date, "type": "time", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.last_update_date.value, + "name": BuiltInField.last_update_date, "type": "time", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.source.value, + "name": BuiltInField.source, "type": "string", } ) @@ -330,42 +330,42 @@ class Document(Base): created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + processing_started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # parsing file_id = mapped_column(sa.Text, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable - parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable + parsing_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # cleaning - cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + cleaning_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # split - splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + splitting_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # indexing - tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + indexing_latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # pause - is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + is_paused: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) paused_by = mapped_column(StringUUID, nullable=True) - paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + paused_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # error error = mapped_column(sa.Text, nullable=True) - stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # basic fields indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying")) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) - archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) @@ -544,7 +544,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.source, "type": "string", - "value": MetadataDataSource[self.data_source_type].value, + "value": MetadataDataSource[self.data_source_type], } ) return built_in_fields @@ -677,17 +677,17 @@ class DocumentSegment(Base): # basic fields hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(sa.Text, nullable=True) - stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) @property def dataset(self): @@ -829,8 +829,8 @@ class ChildChunk(Base): updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) - indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(sa.Text, nullable=True) @property diff --git a/api/models/model.py b/api/models/model.py index 5a4c5de6e1..928508cc48 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ import re import uuid from collections.abc import Mapping from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Literal, Optional, cast from core.plugin.entities.plugin import GenericProviderID @@ -62,9 +62,9 @@ class AppMode(StrEnum): raise ValueError(f"invalid mode value {value}") -class IconType(Enum): - IMAGE = "image" - EMOJI = "emoji" +class IconType(StrEnum): + IMAGE = auto() + EMOJI = auto() class App(Base): @@ -76,9 +76,9 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji icon = mapped_column(String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(String(255)) + icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) @@ -90,7 +90,7 @@ class App(Base): is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) tracing = mapped_column(sa.Text, nullable=True) - max_active_requests: Mapped[Optional[int]] + max_active_requests: Mapped[int | None] created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -134,7 +134,7 @@ class App(Base): return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -149,15 +149,15 @@ class App(Base): if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get( "strategy", "" ) in {"function_call", "react"}: - self.mode = AppMode.AGENT_CHAT.value + self.mode = AppMode.AGENT_CHAT db.session.commit() return True return False @property def mode_compatible_with_agent(self) -> str: - if self.mode == AppMode.CHAT.value and self.is_agent: - return AppMode.AGENT_CHAT.value + if self.mode == AppMode.CHAT and self.is_agent: + return AppMode.AGENT_CHAT return str(self.mode) @@ -291,7 +291,7 @@ class App(Base): return tags or [] @property - def author_name(self) -> Optional[str]: + def author_name(self) -> str | None: if self.created_by: account = db.session.query(Account).where(Account.id == self.created_by).first() if account: @@ -334,7 +334,7 @@ class AppModelConfig(Base): file_upload = mapped_column(sa.Text) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -546,7 +546,7 @@ class RecommendedApp(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -570,12 +570,12 @@ class InstalledApp(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -711,9 +711,9 @@ class Conversation(Base): @property def model_config(self): model_config = {} - app_model_config: Optional[AppModelConfig] = None + app_model_config: AppModelConfig | None = None - if self.mode == AppMode.ADVANCED_CHAT.value: + if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) model_config = override_model_configs @@ -845,7 +845,7 @@ class Conversation(Base): ) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: return db.session.query(App).where(App.id == self.app_id).first() @property @@ -858,7 +858,7 @@ class Conversation(Base): return None @property - def from_account_name(self) -> Optional[str]: + def from_account_name(self) -> str | None: if self.from_account_id: account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: @@ -933,14 +933,14 @@ class Message(Base): status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) error = mapped_column(sa.Text) message_metadata = mapped_column(sa.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) from_source: Mapped[str] = mapped_column(String(255), nullable=False) - from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) - from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) + from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) + from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) @property def inputs(self) -> dict[str, Any]: @@ -1337,9 +1337,9 @@ class MessageFile(Base): message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) - url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + url: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True) + upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1356,8 +1356,8 @@ class MessageAnnotation(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) - message_id: Mapped[Optional[str]] = mapped_column(StringUUID) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) + message_id: Mapped[str | None] = mapped_column(StringUUID) question = mapped_column(sa.Text, nullable=True) content = mapped_column(sa.Text, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) @@ -1459,6 +1459,14 @@ class OperationLog(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) +class DefaultEndUserSessionID(StrEnum): + """ + End User Session ID enum. + """ + + DEFAULT_SESSION_ID = "DEFAULT-USER" + + class EndUser(Base, UserMixin): __tablename__ = "end_users" __table_args__ = ( @@ -1721,18 +1729,18 @@ class MessageAgentThought(Base): # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design tool_process_data = mapped_column(sa.Text, nullable=True) message = mapped_column(sa.Text, nullable=True) - message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) message_unit_price = mapped_column(sa.Numeric, nullable=True) message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) message_files = mapped_column(sa.Text, nullable=True) answer = mapped_column(sa.Text, nullable=True) - answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) answer_unit_price = mapped_column(sa.Numeric, nullable=True) answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) total_price = mapped_column(sa.Numeric, nullable=True) currency = mapped_column(String, nullable=True) - latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1830,11 +1838,11 @@ class DatasetRetrieverResource(Base): document_name = mapped_column(sa.Text, nullable=False) data_source_type = mapped_column(sa.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + score: Mapped[float | None] = mapped_column(sa.Float, nullable=True) content = mapped_column(sa.Text, nullable=False) - hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) index_node_hash = mapped_column(sa.Text, nullable=True) retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/provider.py b/api/models/provider.py index 9a344ea56d..aacc6e505a 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,6 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from functools import cached_property -from typing import Optional import sqlalchemy as sa from sqlalchemy import DateTime, String, func, text @@ -12,9 +11,9 @@ from .engine import db from .types import StringUUID -class ProviderType(Enum): - CUSTOM = "custom" - SYSTEM = "system" +class ProviderType(StrEnum): + CUSTOM = auto() + SYSTEM = auto() @staticmethod def value_of(value: str) -> "ProviderType": @@ -24,14 +23,14 @@ class ProviderType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod @@ -63,14 +62,14 @@ class Provider(Base): String(40), nullable=False, server_default=text("'custom'::character varying") ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - quota_type: Mapped[Optional[str]] = mapped_column( + quota_type: Mapped[str | None] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") ) - quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True) - quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0) + quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -133,7 +132,7 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -201,17 +200,17 @@ class ProviderOrder(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) + payment_id: Mapped[str | None] = mapped_column(String(191)) + transaction_id: Mapped[str | None] = mapped_column(String(191)) quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(String(40)) - total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer) + currency: Mapped[str | None] = mapped_column(String(40)) + total_amount: Mapped[int | None] = mapped_column(sa.Integer) payment_status: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + paid_at: Mapped[datetime | None] = mapped_column(DateTime) + pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) + refunded_at: Mapped[datetime | None] = mapped_column(DateTime) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -255,9 +254,9 @@ class LoadBalancingModelConfig(Base): model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) + encrypted_config: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 8456d65a87..5b4c486bc4 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,6 +1,5 @@ import json from datetime import datetime -from typing import Optional import sqlalchemy as sa from sqlalchemy import DateTime, String, func @@ -27,7 +26,7 @@ class DataSourceOauthBinding(Base): source_info = mapped_column(JSONB, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -45,7 +44,7 @@ class DataSourceApiKeyAuthBinding(Base): credentials = mapped_column(sa.Text, nullable=True) # JSON created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) + disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 9a52fcfb41..3da1674536 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional import sqlalchemy as sa from celery import states @@ -32,7 +31,7 @@ class CeleryTask(Base): args = mapped_column(sa.LargeBinary, nullable=True) kwargs = mapped_column(sa.LargeBinary, nullable=True) worker = mapped_column(String(155), nullable=True) - retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) queue = mapped_column(String(155), nullable=True) @@ -46,4 +45,4 @@ class CeleryTaskSet(Base): ) taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 141393dc8e..040743fb0b 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast import sqlalchemy as sa from deprecated import deprecated @@ -401,13 +401,13 @@ class ToolFile(TypeBase): # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID) # conversation id - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) # file key file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None) + original_url: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None) # name name: Mapped[str] = mapped_column(default="") # size diff --git a/api/models/workflow.py b/api/models/workflow.py index 4686b38b01..9d129a09e2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,8 +2,8 @@ import json import logging from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from enum import StrEnum, auto +from typing import TYPE_CHECKING, Any, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -41,13 +41,13 @@ from .types import EnumText, StringUUID logger = logging.getLogger(__name__) -class WorkflowType(Enum): +class WorkflowType(StrEnum): """ Workflow Type Enum """ - WORKFLOW = "workflow" - CHAT = "chat" + WORKFLOW = auto() + CHAT = auto() @classmethod def value_of(cls, value: str) -> "WorkflowType": @@ -130,7 +130,7 @@ class Workflow(Base): _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) + updated_by: Mapped[str | None] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, @@ -499,18 +499,18 @@ class WorkflowRun(Base): type: Mapped[str] = mapped_column(String(255)) triggered_from: Mapped[str] = mapped_column(String(255)) version: Mapped[str] = mapped_column(String(255)) - graph: Mapped[Optional[str]] = mapped_column(sa.Text) - inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + graph: Mapped[str | None] = mapped_column(sa.Text) + inputs: Mapped[str | None] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded - outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[Optional[str]] = mapped_column(sa.Text) + outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}") + error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) @property @@ -706,24 +706,24 @@ class WorkflowNodeExecutionModel(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) triggered_from: Mapped[str] = mapped_column(String(255)) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) + node_execution_id: Mapped[str | None] = mapped_column(String(255)) node_id: Mapped[str] = mapped_column(String(255)) node_type: Mapped[str] = mapped_column(String(255)) title: Mapped[str] = mapped_column(String(255)) - inputs: Mapped[Optional[str]] = mapped_column(sa.Text) - process_data: Mapped[Optional[str]] = mapped_column(sa.Text) - outputs: Mapped[Optional[str]] = mapped_column(sa.Text) + inputs: Mapped[str | None] = mapped_column(sa.Text) + process_data: Mapped[str | None] = mapped_column(sa.Text) + outputs: Mapped[str | None] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) - error: Mapped[Optional[str]] = mapped_column(sa.Text) + error: Mapped[str | None] = mapped_column(sa.Text) elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) - execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text) + execution_metadata: Mapped[str | None] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + finished_at: Mapped[datetime | None] = mapped_column(DateTime) @property def created_by_account(self): @@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base): return extras -class WorkflowAppLogCreatedFrom(Enum): +class WorkflowAppLogCreatedFrom(StrEnum): """ Workflow App Log Created From Enum """ diff --git a/api/pyproject.toml b/api/pyproject.toml index c59140e246..f4fe63f6b6 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -79,7 +79,7 @@ dependencies = [ "sqlalchemy~=2.0.29", "starlette==0.47.2", "tiktoken~=0.9.0", - "transformers~=4.53.0", + "transformers~=4.56.1", "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", "weave~=0.51.0", "yarl~=1.18.3", @@ -168,6 +168,8 @@ dev = [ "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", "mypy~=1.17.1", + "locust>=2.40.4", + "sseclient-py>=1.8.0", ] ############################################################ diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 00a2d1f87d..fa2c94b623 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -11,7 +11,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m from collections.abc import Sequence from datetime import datetime -from typing import Optional, Protocol +from typing import Protocol from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel @@ -44,7 +44,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr app_id: str, workflow_id: str, node_id: str, - ) -> Optional[WorkflowNodeExecutionModel]: + ) -> WorkflowNodeExecutionModel | None: """ Get the most recent execution for a specific node. @@ -87,8 +87,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr def get_execution_by_id( self, execution_id: str, - tenant_id: Optional[str] = None, - ) -> Optional[WorkflowNodeExecutionModel]: + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: """ Get a workflow node execution by its ID. diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 59e7baeb79..3ac28fad75 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -36,7 +36,7 @@ Example: from collections.abc import Sequence from datetime import datetime -from typing import Optional, Protocol +from typing import Protocol from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -58,7 +58,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): app_id: str, triggered_from: str, limit: int = 20, - last_id: Optional[str] = None, + last_id: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -90,7 +90,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): tenant_id: str, app_id: str, run_id: str, - ) -> Optional[WorkflowRun]: + ) -> WorkflowRun | None: """ Get a specific workflow run by ID. diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index e6a23ddf9f..cbb09af542 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -7,7 +7,6 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. from collections.abc import Sequence from datetime import datetime -from typing import Optional from sqlalchemy import delete, desc, select from sqlalchemy.orm import Session, sessionmaker @@ -49,7 +48,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut app_id: str, workflow_id: str, node_id: str, - ) -> Optional[WorkflowNodeExecutionModel]: + ) -> WorkflowNodeExecutionModel | None: """ Get the most recent execution for a specific node. @@ -116,8 +115,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut def get_execution_by_id( self, execution_id: str, - tenant_id: Optional[str] = None, - ) -> Optional[WorkflowNodeExecutionModel]: + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: """ Get a workflow node execution by its ID. diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 6294846f5e..205f8c87ee 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -22,7 +22,6 @@ Implementation Notes: import logging from collections.abc import Sequence from datetime import datetime -from typing import Optional from sqlalchemy import delete, select from sqlalchemy.orm import Session, sessionmaker @@ -61,7 +60,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): app_id: str, triggered_from: str, limit: int = 20, - last_id: Optional[str] = None, + last_id: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -107,7 +106,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): tenant_id: str, app_id: str, run_id: str, - ) -> Optional[WorkflowRun]: + ) -> WorkflowRun | None: """ Get a specific workflow run by ID with tenant and app isolation. """ diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 2b1e6b47cc..9efd46ba5d 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -1,6 +1,6 @@ import datetime import time -from typing import Optional, TypedDict +from typing import TypedDict import click from sqlalchemy import func, select @@ -17,7 +17,7 @@ from services.feature_service import FeatureService class CleanupConfig(TypedDict): clean_day: datetime.datetime - plan_filter: Optional[str] + plan_filter: str | None add_logs: bool diff --git a/api/services/account_service.py b/api/services/account_service.py index f917959350..8362e415c1 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -5,7 +5,7 @@ import secrets import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import BaseModel from sqlalchemy import func @@ -37,7 +37,6 @@ from services.billing_service import BillingService from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountNotLinkTenantError, AccountPasswordError, AccountRegisterError, @@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import ( send_old_owner_transfer_notify_email_task, send_owner_transfer_confirm_task, ) -from tasks.mail_reset_password_task import send_reset_password_mail_task +from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist +from tasks.mail_reset_password_task import ( + send_reset_password_mail_task, + send_reset_password_mail_task_when_account_not_exist, +) logger = logging.getLogger(__name__) @@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) + email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( - prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 + prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1 ) email_code_account_deletion_rate_limiter = RateLimiter( prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 @@ -95,6 +99,7 @@ class AccountService: FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 CHANGE_EMAIL_MAX_ERROR_LIMITS = 5 OWNER_TRANSFER_MAX_ERROR_LIMITS = 5 + EMAIL_REGISTER_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -166,12 +171,12 @@ class AccountService: return token @staticmethod - def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: + def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: """authenticate account with email and password""" account = db.session.query(Account).filter_by(email=email).first() if not account: - raise AccountNotFoundError() + raise AccountPasswordError("Invalid email or password.") if account.status == AccountStatus.BANNED.value: raise AccountLoginError("Account is banned.") @@ -223,9 +228,9 @@ class AccountService: email: str, name: str, interface_language: str, - password: Optional[str] = None, + password: str | None = None, interface_theme: str = "light", - is_setup: Optional[bool] = False, + is_setup: bool | None = False, ) -> Account: """create account""" if not FeatureService.get_system_features().is_allow_register and not is_setup: @@ -271,7 +276,7 @@ class AccountService: @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: Optional[str] = None + email: str, name: str, interface_language: str, password: str | None = None ) -> Account: """create account""" account = AccountService.create_account( @@ -296,7 +301,9 @@ class AccountService: if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError - raise EmailCodeAccountDeletionRateLimitExceededError() + raise EmailCodeAccountDeletionRateLimitExceededError( + int(cls.email_code_account_deletion_rate_limiter.time_window / 60) + ) send_account_deletion_verification_code.delay(to=email, code=code) @@ -323,7 +330,7 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = ( + account_integrate: AccountIntegrate | None = ( db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() ) @@ -384,7 +391,7 @@ class AccountService: db.session.commit() @staticmethod - def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: + def login(account: Account, *, ip_address: str | None = None) -> TokenPair: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) @@ -432,9 +439,10 @@ class AccountService: @classmethod def send_reset_password_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", + is_allow_register: bool = False, ): account_email = account.email if account else email if account_email is None: @@ -443,26 +451,67 @@ class AccountService: if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError - raise PasswordResetRateLimitExceededError() + raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60)) code, token = cls.generate_reset_password_token(account_email, account) - send_reset_password_mail_task.delay( - language=language, - to=account_email, - code=code, - ) + if account: + send_reset_password_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + else: + send_reset_password_mail_task_when_account_not_exist.delay( + language=language, + to=account_email, + is_allow_register=is_allow_register, + ) cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def send_email_register_email( + cls, + account: Account | None = None, + email: str | None = None, + language: str = "en-US", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.email_register_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import EmailRegisterRateLimitExceededError + + raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60)) + + code, token = cls.generate_email_register_token(account_email) + + if account: + send_email_register_mail_task_when_account_exist.delay( + language=language, + to=account_email, + account_name=account.name, + ) + + else: + send_email_register_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + cls.email_register_rate_limiter.increment_rate_limit(account_email) + return token + @classmethod def send_change_email_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, - old_email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, + old_email: str | None = None, language: str = "en-US", - phase: Optional[str] = None, + phase: str | None = None, ): account_email = account.email if account else email if account_email is None: @@ -473,7 +522,7 @@ class AccountService: if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError - raise EmailChangeRateLimitExceededError() + raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) @@ -489,8 +538,8 @@ class AccountService: @classmethod def send_change_email_completed_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", ): account_email = account.email if account else email @@ -505,10 +554,10 @@ class AccountService: @classmethod def send_owner_transfer_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", ): account_email = account.email if account else email if account_email is None: @@ -517,7 +566,7 @@ class AccountService: if cls.owner_transfer_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import OwnerTransferRateLimitExceededError - raise OwnerTransferRateLimitExceededError() + raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60)) code, token = cls.generate_owner_transfer_token(account_email, account) workspace_name = workspace_name or "" @@ -534,10 +583,10 @@ class AccountService: @classmethod def send_old_owner_transfer_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", new_owner_email: str = "", ): account_email = account.email if account else email @@ -555,10 +604,10 @@ class AccountService: @classmethod def send_new_owner_transfer_notify_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", - workspace_name: Optional[str] = "", + workspace_name: str | None = "", ): account_email = account.email if account else email if account_email is None: @@ -575,8 +624,8 @@ class AccountService: def generate_reset_password_token( cls, email: str, - account: Optional[Account] = None, - code: Optional[str] = None, + account: Account | None = None, + code: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -587,13 +636,26 @@ class AccountService: ) return code, token + @classmethod + def generate_email_register_token( + cls, + email: str, + code: str | None = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data) + return code, token + @classmethod def generate_change_email_token( cls, email: str, - account: Optional[Account] = None, - code: Optional[str] = None, - old_email: Optional[str] = None, + account: Account | None = None, + code: str | None = None, + old_email: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -609,8 +671,8 @@ class AccountService: def generate_owner_transfer_token( cls, email: str, - account: Optional[Account] = None, - code: Optional[str] = None, + account: Account | None = None, + code: str | None = None, additional_data: dict[str, Any] = {}, ): if not code: @@ -625,6 +687,10 @@ class AccountService: def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") + @classmethod + def revoke_email_register_token(cls, token: str): + TokenManager.revoke_token(token, "email_register") + @classmethod def revoke_change_email_token(cls, token: str): TokenManager.revoke_token(token, "change_email") @@ -634,22 +700,26 @@ class AccountService: TokenManager.revoke_token(token, "owner_transfer") @classmethod - def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_reset_password_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "reset_password") @classmethod - def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_email_register_data(cls, token: str) -> dict[str, Any] | None: + return TokenManager.get_token_data(token, "email_register") + + @classmethod + def get_change_email_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "change_email") @classmethod - def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "owner_transfer") @classmethod def send_email_code_login_email( cls, - account: Optional[Account] = None, - email: Optional[str] = None, + account: Account | None = None, + email: str | None = None, language: str = "en-US", ): email = account.email if account else email @@ -658,7 +728,7 @@ class AccountService: if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError - raise EmailCodeLoginRateLimitExceededError() + raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60)) code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( @@ -673,7 +743,7 @@ class AccountService: return token @classmethod - def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @classmethod @@ -744,6 +814,16 @@ class AccountService: count = int(count) + 1 redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count) + @staticmethod + @redis_fallback(default_return=None) + def add_email_register_error_rate_limit(email: str) -> None: + key = f"email_register_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count) + @staticmethod @redis_fallback(default_return=False) def is_forgot_password_error_rate_limit(email: str) -> bool: @@ -763,6 +843,24 @@ class AccountService: key = f"forgot_password_error_rate_limit:{email}" redis_client.delete(key) + @staticmethod + @redis_fallback(default_return=False) + def is_email_register_error_rate_limit(email: str) -> bool: + key = f"email_register_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_email_register_error_rate_limit(email: str): + key = f"email_register_error_rate_limit:{email}" + redis_client.delete(key) + @staticmethod @redis_fallback(default_return=None) def add_change_email_error_rate_limit(email: str): @@ -867,7 +965,7 @@ class AccountService: class TenantService: @staticmethod - def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: + def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant: """Create tenant""" if ( not FeatureService.get_system_features().is_allow_create_workspace @@ -898,9 +996,7 @@ class TenantService: return tenant @staticmethod - def create_owner_tenant_if_not_exist( - account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False - ): + def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): """Check if user have a workspace or not""" available_ta = ( db.session.query(TenantAccountJoin) @@ -972,7 +1068,7 @@ class TenantService: return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: Optional[str] = None): + def switch_tenant(account: Account, tenant_id: str | None = None): """Switch the current workspace for the account""" # Ensure tenant_id is provided @@ -1054,7 +1150,7 @@ class TenantService: ) @staticmethod - def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]: + def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: """Get the role of the current account for a given tenant""" join = ( db.session.query(TenantAccountJoin) @@ -1194,13 +1290,13 @@ class RegisterService: cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None, - is_setup: Optional[bool] = False, - create_workspace_required: Optional[bool] = True, + password: str | None = None, + open_id: str | None = None, + provider: str | None = None, + language: str | None = None, + status: AccountStatus | None = None, + is_setup: bool | None = False, + create_workspace_required: bool | None = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -1317,9 +1413,7 @@ class RegisterService: redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid( - cls, workspace_id: Optional[str], email: str, token: str - ) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -1358,8 +1452,8 @@ class RegisterService: @classmethod def get_invitation_by_token( - cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None - ) -> Optional[dict[str, str]]: + cls, token: str, workspace_id: str | None = None, email: str | None = None + ) -> dict[str, str] | None: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 6f0ab2546a..f2ffa3b170 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -32,14 +32,14 @@ class AdvancedPromptTemplateService: def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): context_prompt = copy.deepcopy(CONTEXT) - if app_mode == AppMode.CHAT.value: + if app_mode == AppMode.CHAT: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - elif app_mode == AppMode.COMPLETION.value: + elif app_mode == AppMode.COMPLETION: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt @@ -73,7 +73,7 @@ class AdvancedPromptTemplateService: def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == AppMode.CHAT.value: + if app_mode == AppMode.CHAT: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt @@ -82,7 +82,7 @@ class AdvancedPromptTemplateService: return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) - elif app_mode == AppMode.COMPLETION.value: + elif app_mode == AppMode.COMPLETION: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 8578f38a0d..d631ce812f 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,5 @@ import threading -from typing import Any, Optional +from typing import Any import pytz @@ -35,7 +35,7 @@ class AgentService: if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Optional[Message] = ( + message: Message | None = ( db.session.query(Message) .where( Message.id == message_id, diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 34681ba111..9feca7337f 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional import pandas as pd from sqlalchemy import or_, select @@ -42,7 +41,7 @@ class AppAnnotationService: if not message: raise NotFound("Message Not Exists.") - annotation: Optional[MessageAnnotation] = message.annotation + annotation: MessageAnnotation | None = message.annotation # save the message annotation if annotation: annotation.content = args["answer"] diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 49ff28d191..1c4a9b96ec 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,7 +4,6 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum -from typing import Optional from urllib.parse import urlparse from uuid import uuid4 @@ -61,8 +60,8 @@ class ImportStatus(StrEnum): class Import(BaseModel): id: str status: ImportStatus - app_id: Optional[str] = None - app_mode: Optional[str] = None + app_id: str | None = None + app_mode: str | None = None current_dsl_version: str = CURRENT_DSL_VERSION imported_dsl_version: str = "" error: str = "" @@ -121,14 +120,14 @@ class AppDslService: *, account: Account, import_mode: str, - yaml_content: Optional[str] = None, - yaml_url: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - icon_type: Optional[str] = None, - icon: Optional[str] = None, - icon_background: Optional[str] = None, - app_id: Optional[str] = None, + yaml_content: str | None = None, + yaml_url: str | None = None, + name: str | None = None, + description: str | None = None, + icon_type: str | None = None, + icon: str | None = None, + icon_background: str | None = None, + app_id: str | None = None, ) -> Import: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) @@ -407,15 +406,15 @@ class AppDslService: def _create_or_update_app( self, *, - app: Optional[App], + app: App | None, data: dict, account: Account, - name: Optional[str] = None, - description: Optional[str] = None, - icon_type: Optional[str] = None, - icon: Optional[str] = None, - icon_background: Optional[str] = None, - dependencies: Optional[list[PluginDependency]] = None, + name: str | None = None, + description: str | None = None, + icon_type: str | None = None, + icon: str | None = None, + icon_background: str | None = None, + dependencies: list[PluginDependency] | None = None, ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) @@ -533,7 +532,7 @@ class AppDslService: return app @classmethod - def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: Optional[str] = None) -> str: + def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str: """ Export app :param app_model: App instance @@ -566,7 +565,7 @@ class AppDslService: @classmethod def _append_workflow_export_data( - cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None + cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None ): """ Append workflow export data diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index e812fcc992..1fae452d38 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Optional, Union +from typing import Any, Union from openai._exceptions import RateLimitError @@ -60,7 +60,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) - if app_model.mode == AppMode.COMPLETION.value: + if app_model.mode == AppMode.COMPLETION: return rate_limit.generate( CompletionAppGenerator.convert_to_event_stream( CompletionAppGenerator().generate( @@ -69,7 +69,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: return rate_limit.generate( AgentChatAppGenerator.convert_to_event_stream( AgentChatAppGenerator().generate( @@ -78,7 +78,7 @@ class AppGenerateService: ), request_id, ) - elif app_model.mode == AppMode.CHAT.value: + elif app_model.mode == AppMode.CHAT: return rate_limit.generate( ChatAppGenerator.convert_to_event_stream( ChatAppGenerator().generate( @@ -87,7 +87,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.ADVANCED_CHAT.value: + elif app_model.mode == AppMode.ADVANCED_CHAT: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( @@ -103,7 +103,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( @@ -155,14 +155,14 @@ class AppGenerateService: @classmethod def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().single_iteration_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_iteration_generate( @@ -174,14 +174,14 @@ class AppGenerateService: @classmethod def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().single_loop_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_loop_generate( @@ -214,7 +214,7 @@ class AppGenerateService: ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow: """ Get workflow :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index 9b200a570d..ab2b38ec01 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, TypedDict, cast +from typing import TypedDict, cast from flask_sqlalchemy.pagination import Pagination @@ -40,15 +40,15 @@ class AppService: filters = [App.tenant_id == tenant_id, App.is_universal == False] if args["mode"] == "workflow": - filters.append(App.mode == AppMode.WORKFLOW.value) + filters.append(App.mode == AppMode.WORKFLOW) elif args["mode"] == "completion": - filters.append(App.mode == AppMode.COMPLETION.value) + filters.append(App.mode == AppMode.COMPLETION) elif args["mode"] == "chat": - filters.append(App.mode == AppMode.CHAT.value) + filters.append(App.mode == AppMode.CHAT) elif args["mode"] == "advanced-chat": - filters.append(App.mode == AppMode.ADVANCED_CHAT.value) + filters.append(App.mode == AppMode.ADVANCED_CHAT) elif args["mode"] == "agent-chat": - filters.append(App.mode == AppMode.AGENT_CHAT.value) + filters.append(App.mode == AppMode.AGENT_CHAT) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) @@ -171,7 +171,7 @@ class AppService: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None # get original app model config - if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: + if app.mode == AppMode.AGENT_CHAT or app.is_agent: model_config = app.app_model_config if not model_config: return app @@ -370,7 +370,7 @@ class AppService: } ) else: - app_model_config: Optional[AppModelConfig] = app_model.app_model_config + app_model_config: AppModelConfig | None = app_model.app_model_config if not app_model_config: return meta @@ -393,7 +393,7 @@ class AppService: meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: Optional[ApiToolProvider] = ( + provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first() ) if provider is None: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 9b1999d813..1158fc5197 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -2,7 +2,6 @@ import io import logging import uuid from collections.abc import Generator -from typing import Optional from flask import Response, stream_with_context from werkzeug.datastructures import FileStorage @@ -30,8 +29,8 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod - def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None): + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -77,18 +76,18 @@ class AudioService: def transcript_tts( cls, app_model: App, - text: Optional[str] = None, - voice: Optional[str] = None, - end_user: Optional[str] = None, - message_id: Optional[str] = None, + text: str | None = None, + voice: str | None = None, + end_user: str | None = None, + message_id: str | None = None, is_draft: bool = False, ): from app import app - def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False): + def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False): with app.app_context(): if voice is None: - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft: workflow = WorkflowService().get_draft_workflow(app_model=app_model) else: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 066bed3234..a364862a88 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,5 +1,5 @@ import os -from typing import Literal, Optional +from typing import Literal import httpx from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed @@ -73,7 +73,7 @@ class BillingService: def is_tenant_owner_or_admin(current_user: Account): tenant_id = current_user.current_tenant_id - join: Optional[TenantAccountJoin] = ( + join: TenantAccountJoin | None = ( db.session.query(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index d017ce54ab..a8e51a426d 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,7 +1,7 @@ import contextlib import logging from collections.abc import Callable, Sequence -from typing import Any, Optional, Union +from typing import Any, Union from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -36,12 +36,12 @@ class ConversationService: *, session: Session, app_model: App, - user: Optional[Union[Account, EndUser]], - last_id: Optional[str], + user: Union[Account, EndUser] | None, + last_id: str | None, limit: int, invoke_from: InvokeFrom, - include_ids: Optional[Sequence[str]] = None, - exclude_ids: Optional[Sequence[str]] = None, + include_ids: Sequence[str] | None = None, + exclude_ids: Sequence[str] | None = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: @@ -118,7 +118,7 @@ class ConversationService: cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, name: str, auto_generate: bool, ): @@ -158,7 +158,7 @@ class ConversationService: return conversation @classmethod - def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): conversation = ( db.session.query(Conversation) .where( @@ -178,7 +178,7 @@ class ConversationService: return conversation @classmethod - def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): try: logger.info( "Initiating conversation deletion for app_name %s, conversation_id: %s", @@ -200,9 +200,9 @@ class ConversationService: cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, limit: int, - last_id: Optional[str], + last_id: str | None, ) -> InfiniteScrollPagination: conversation = cls.get_conversation(app_model, conversation_id, user) @@ -222,8 +222,8 @@ class ConversationService: # Filter for variables created after the last_id stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at) - # Apply limit to query - query_stmt = stmt.limit(limit) # Get one extra to check if there are more + # Apply limit to query: fetch one extra row to determine has_more + query_stmt = stmt.limit(limit + 1) rows = session.scalars(query_stmt).all() has_more = False @@ -248,7 +248,7 @@ class ConversationService: app_model: App, conversation_id: str, variable_id: str, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, new_value: Any, ): """ diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 47bd06a7cc..102629629d 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal import sqlalchemy as sa from sqlalchemy import exists, func, select @@ -185,16 +185,16 @@ class DatasetService: def create_empty_dataset( tenant_id: str, name: str, - description: Optional[str], - indexing_technique: Optional[str], + description: str | None, + indexing_technique: str | None, account: Account, - permission: Optional[str] = None, + permission: str | None = None, provider: str = "vendor", - external_knowledge_api_id: Optional[str] = None, - external_knowledge_id: Optional[str] = None, - embedding_model_provider: Optional[str] = None, - embedding_model_name: Optional[str] = None, - retrieval_model: Optional[RetrievalModel] = None, + external_knowledge_api_id: str | None = None, + external_knowledge_id: str | None = None, + embedding_model_provider: str | None = None, + embedding_model_name: str | None = None, + retrieval_model: RetrievalModel | None = None, ): # check if dataset name already exists if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): @@ -257,8 +257,8 @@ class DatasetService: return dataset @staticmethod - def get_dataset(dataset_id) -> Optional[Dataset]: - dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Dataset | None: + dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset @staticmethod @@ -694,7 +694,7 @@ class DatasetService: raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None): if not dataset: raise ValueError("Dataset not found") @@ -868,7 +868,7 @@ class DocumentService: } @staticmethod - def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + def get_document(dataset_id: str, document_id: str | None = None) -> Document | None: if document_id: document = ( db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() @@ -878,7 +878,7 @@ class DocumentService: return None @staticmethod - def get_document_by_id(document_id: str) -> Optional[Document]: + def get_document_by_id(document_id: str) -> Document | None: document = db.session.query(Document).where(Document.id == document_id).first() return document @@ -1004,7 +1004,7 @@ class DocumentService: if dataset.built_in_field_enabled: if document.doc_metadata: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata[BuiltInField.document_name.value] = name + doc_metadata[BuiltInField.document_name] = name document.doc_metadata = doc_metadata document.name = name @@ -1099,7 +1099,7 @@ class DocumentService: dataset: Dataset, knowledge_config: KnowledgeConfig, account: Account | Any, - dataset_process_rule: Optional[DatasetProcessRule] = None, + dataset_process_rule: DatasetProcessRule | None = None, created_from: str = "web", ) -> tuple[list[Document], str]: # check doc_form @@ -1463,7 +1463,7 @@ class DocumentService: dataset: Dataset, document_data: KnowledgeConfig, account: Account, - dataset_process_rule: Optional[DatasetProcessRule] = None, + dataset_process_rule: DatasetProcessRule | None = None, created_from: str = "web", ): assert isinstance(current_user, Account) @@ -2365,7 +2365,22 @@ class SegmentService: if segment.enabled: # send delete segment index task redis_client.setex(indexing_cache_key, 600, 1) - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) + + # Get child chunk IDs before parent segment is deleted + child_node_ids = [] + if segment.index_node_id: + child_chunks = ( + db.session.query(ChildChunk.index_node_id) + .where( + ChildChunk.segment_id == segment.id, + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] + + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + db.session.delete(segment) # update document word count assert document.word_count is not None @@ -2375,9 +2390,13 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): - assert isinstance(current_user, Account) - segments = ( - db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) + assert current_user is not None + # Check if segment_ids is not empty to avoid WHERE false condition + if not segment_ids or len(segment_ids) == 0: + return + segments_info = ( + db.session.query(DocumentSegment) + .with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count) .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -2387,18 +2406,36 @@ class SegmentService: .all() ) - if not segments: + if not segments_info: return - index_node_ids = [seg.index_node_id for seg in segments] - total_words = sum(seg.word_count for seg in segments) + index_node_ids = [info[0] for info in segments_info] + segment_db_ids = [info[1] for info in segments_info] + total_words = sum(info[2] for info in segments_info if info[2] is not None) + + # Get child chunk IDs before parent segments are deleted + child_node_ids = [] + if index_node_ids: + child_chunks = ( + db.session.query(ChildChunk.index_node_id) + .where( + ChildChunk.segment_id.in_(segment_db_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] + + # Start async cleanup with both parent and child node IDs + if index_node_ids or child_node_ids: + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) document.word_count = ( document.word_count - total_words if document.word_count and document.word_count > total_words else 0 ) db.session.add(document) - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + # Delete database records db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() db.session.commit() @@ -2618,7 +2655,7 @@ class SegmentService: @classmethod def get_child_chunks( - cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None ): assert isinstance(current_user, Account) @@ -2637,7 +2674,7 @@ class SegmentService: return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod - def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: + def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None: """Get a child chunk by its ID.""" result = ( db.session.query(ChildChunk) @@ -2674,7 +2711,7 @@ class SegmentService: return paginated_segments.items, paginated_segments.total @classmethod - def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: + def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: """Get a segment by its ID.""" result = ( db.session.query(DocumentSegment) @@ -2738,11 +2775,7 @@ class DatasetPermissionService: ).where(DatasetPermission.dataset_id == dataset_id) ).all() - user_list = [] - for user in user_list_query: - user_list.append(user.account_id) - - return user_list + return user_list_query @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index ee8a932ded..1065d3842a 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -9,9 +9,9 @@ from services.errors.base import BaseServiceError logger = logging.getLogger(__name__) -class PluginCredentialType(enum.IntEnum): - MODEL = enum.auto() - TOOL = enum.auto() +class PluginCredentialType(enum.Enum): + MODEL = 0 # must be 0 for API contract compatibility + TOOL = 1 # must be 1 for API contract compatibility def to_number(self): return self.value diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index 4545f385eb..c9fb1c9e21 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal, Union from pydantic import BaseModel @@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel): class Authorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[AuthorizationConfig] = None + config: AuthorizationConfig | None = None class ProcessStatusSetting(BaseModel): @@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel): class ExternalKnowledgeApiSetting(BaseModel): url: str request_method: str - headers: Optional[dict] = None - params: Optional[dict] = None + headers: dict | None = None + params: dict | None = None diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 344c67885e..94ce9d5415 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -11,14 +11,14 @@ class ParentMode(StrEnum): class NotionIcon(BaseModel): type: str - url: Optional[str] = None - emoji: Optional[str] = None + url: str | None = None + emoji: str | None = None class NotionPage(BaseModel): page_id: str page_name: str - page_icon: Optional[NotionIcon] = None + page_icon: NotionIcon | None = None type: str @@ -40,9 +40,9 @@ class FileInfo(BaseModel): class InfoList(BaseModel): data_source_type: Literal["upload_file", "notion_import", "website_crawl"] - notion_info_list: Optional[list[NotionInfo]] = None - file_info_list: Optional[FileInfo] = None - website_info_list: Optional[WebsiteInfo] = None + notion_info_list: list[NotionInfo] | None = None + file_info_list: FileInfo | None = None + website_info_list: WebsiteInfo | None = None class DataSource(BaseModel): @@ -61,20 +61,20 @@ class Segmentation(BaseModel): class Rule(BaseModel): - pre_processing_rules: Optional[list[PreProcessingRule]] = None - segmentation: Optional[Segmentation] = None - parent_mode: Optional[Literal["full-doc", "paragraph"]] = None - subchunk_segmentation: Optional[Segmentation] = None + pre_processing_rules: list[PreProcessingRule] | None = None + segmentation: Segmentation | None = None + parent_mode: Literal["full-doc", "paragraph"] | None = None + subchunk_segmentation: Segmentation | None = None class ProcessRule(BaseModel): mode: Literal["automatic", "custom", "hierarchical"] - rules: Optional[Rule] = None + rules: Rule | None = None class RerankingModel(BaseModel): - reranking_provider_name: Optional[str] = None - reranking_model_name: Optional[str] = None + reranking_provider_name: str | None = None + reranking_model_name: str | None = None class WeightVectorSetting(BaseModel): @@ -88,20 +88,20 @@ class WeightKeywordSetting(BaseModel): class WeightModel(BaseModel): - weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None - vector_setting: Optional[WeightVectorSetting] = None - keyword_setting: Optional[WeightKeywordSetting] = None + weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None + vector_setting: WeightVectorSetting | None = None + keyword_setting: WeightKeywordSetting | None = None class RetrievalModel(BaseModel): search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"] reranking_enable: bool - reranking_model: Optional[RerankingModel] = None - reranking_mode: Optional[str] = None + reranking_model: RerankingModel | None = None + reranking_mode: str | None = None top_k: int score_threshold_enabled: bool - score_threshold: Optional[float] = None - weights: Optional[WeightModel] = None + score_threshold: float | None = None + weights: WeightModel | None = None class MetaDataConfig(BaseModel): @@ -110,29 +110,29 @@ class MetaDataConfig(BaseModel): class KnowledgeConfig(BaseModel): - original_document_id: Optional[str] = None + original_document_id: str | None = None duplicate: bool = True indexing_technique: Literal["high_quality", "economy"] - data_source: Optional[DataSource] = None - process_rule: Optional[ProcessRule] = None - retrieval_model: Optional[RetrievalModel] = None + data_source: DataSource | None = None + process_rule: ProcessRule | None = None + retrieval_model: RetrievalModel | None = None doc_form: str = "text_model" doc_language: str = "English" - embedding_model: Optional[str] = None - embedding_model_provider: Optional[str] = None - name: Optional[str] = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + name: str | None = None class SegmentUpdateArgs(BaseModel): - content: Optional[str] = None - answer: Optional[str] = None - keywords: Optional[list[str]] = None + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None regenerate_child_chunks: bool = False - enabled: Optional[bool] = None + enabled: bool | None = None class ChildChunkUpdateArgs(BaseModel): - id: Optional[str] = None + id: str | None = None content: str @@ -143,13 +143,13 @@ class MetadataArgs(BaseModel): class MetadataUpdateArgs(BaseModel): name: str - value: Optional[str | int | float] = None + value: str | int | float | None = None class MetadataDetail(BaseModel): id: str name: str - value: Optional[str | int | float] = None + value: str | int | float | None = None class DocumentMetadataOperation(BaseModel): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 647052d739..49d48f044c 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -42,11 +41,11 @@ class CustomConfigurationResponse(BaseModel): """ status: CustomConfigurationStatus - current_credential_id: Optional[str] = None - current_credential_name: Optional[str] = None - available_credentials: Optional[list[CredentialConfiguration]] = None - custom_models: Optional[list[CustomModelConfiguration]] = None - can_added_models: Optional[list[UnaddedModelConfiguration]] = None + current_credential_id: str | None = None + current_credential_name: str | None = None + available_credentials: list[CredentialConfiguration] | None = None + custom_models: list[CustomModelConfiguration] | None = None + can_added_models: list[UnaddedModelConfiguration] | None = None class SystemConfigurationResponse(BaseModel): @@ -55,7 +54,7 @@ class SystemConfigurationResponse(BaseModel): """ enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None + current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] @@ -67,15 +66,15 @@ class ProviderResponse(BaseModel): tenant_id: str provider: str label: I18nObject - description: Optional[I18nObject] = None - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None - background: Optional[str] = None - help: Optional[ProviderHelpEntity] = None + description: I18nObject | None = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None + background: str | None = None + help: ProviderHelpEntity | None = None supported_model_types: list[ModelType] configurate_methods: list[ConfigurateMethod] - provider_credential_schema: Optional[ProviderCredentialSchema] = None - model_credential_schema: Optional[ModelCredentialSchema] = None + provider_credential_schema: ProviderCredentialSchema | None = None + model_credential_schema: ModelCredentialSchema | None = None preferred_provider_type: ProviderType custom_configuration: CustomConfigurationResponse system_configuration: SystemConfigurationResponse @@ -108,8 +107,8 @@ class ProviderWithModelsResponse(BaseModel): tenant_id: str provider: str label: I18nObject - icon_small: Optional[I18nObject] = None - icon_large: Optional[I18nObject] = None + icon_small: I18nObject | None = None + icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] diff --git a/api/services/errors/base.py b/api/services/errors/base.py index 35ea28468e..0f9631190f 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,6 +1,3 @@ -from typing import Optional - - class BaseServiceError(ValueError): - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py index ca4c9a611d..5bf34f3aa6 100644 --- a/api/services/errors/llm.py +++ b/api/services/errors/llm.py @@ -1,12 +1,9 @@ -from typing import Optional - - class InvokeError(Exception): """Base class for all LLM exceptions.""" - description: Optional[str] = None + description: str | None = None - def __init__(self, description: Optional[str] = None): + def __init__(self, description: str | None = None): self.description = description def __str__(self): diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 3911b763b6..b6ba3bafea 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from urllib.parse import urlparse import httpx @@ -100,7 +100,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + external_knowledge_api: ExternalKnowledgeApis | None = ( db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first() ) if external_knowledge_api is None: @@ -109,7 +109,7 @@ class ExternalDatasetService: @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + external_knowledge_api: ExternalKnowledgeApis | None = ( db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() ) if external_knowledge_api is None: @@ -151,7 +151,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ( + external_knowledge_binding: ExternalKnowledgeBindings | None = ( db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() ) if not external_knowledge_binding: @@ -203,7 +203,7 @@ class ExternalDatasetService: return response @staticmethod - def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]: + def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]: authorization = deepcopy(authorization) if headers: headers = deepcopy(headers) @@ -277,7 +277,7 @@ class ExternalDatasetService: dataset_id: str, query: str, external_retrieval_parameters: dict, - metadata_condition: Optional[MetadataCondition] = None, + metadata_condition: MetadataCondition | None = None, ): external_knowledge_binding = ( db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() diff --git a/api/services/message_service.py b/api/services/message_service.py index 13c8e948ca..e2e27443ba 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Union +from typing import Union from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -29,9 +29,9 @@ class MessageService: def pagination_by_first_id( cls, app_model: App, - user: Optional[Union[Account, EndUser]], + user: Union[Account, EndUser] | None, conversation_id: str, - first_id: Optional[str], + first_id: str | None, limit: int, order: str = "asc", ) -> InfiniteScrollPagination: @@ -91,11 +91,11 @@ class MessageService: def pagination_by_last_id( cls, app_model: App, - user: Optional[Union[Account, EndUser]], - last_id: Optional[str], + user: Union[Account, EndUser] | None, + last_id: str | None, limit: int, - conversation_id: Optional[str] = None, - include_ids: Optional[list] = None, + conversation_id: str | None = None, + include_ids: list | None = None, ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -145,9 +145,9 @@ class MessageService: *, app_model: App, message_id: str, - user: Optional[Union[Account, EndUser]], - rating: Optional[str], - content: Optional[str], + user: Union[Account, EndUser] | None, + rating: str | None, + content: str | None, ): if not user: raise ValueError("user cannot be None") @@ -196,7 +196,7 @@ class MessageService: return [record.to_dict() for record in feedbacks] @classmethod - def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): message = ( db.session.query(Message) .where( @@ -216,7 +216,7 @@ class MessageService: @classmethod def get_suggested_questions_after_answer( - cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom ) -> list[Message]: if not user: raise ValueError("user cannot be None") @@ -229,7 +229,7 @@ class MessageService: model_manager = ModelManager() - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() if invoke_from == InvokeFrom.DEBUGGER: workflow = workflow_service.get_draft_workflow(app_model=app_model) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 05fa5a95bc..6add830813 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -1,6 +1,5 @@ import copy import logging -from typing import Optional from flask_login import current_user @@ -131,11 +130,11 @@ class MetadataService: @staticmethod def get_built_in_fields(): return [ - {"name": BuiltInField.document_name.value, "type": "string"}, - {"name": BuiltInField.uploader.value, "type": "string"}, - {"name": BuiltInField.upload_date.value, "type": "time"}, - {"name": BuiltInField.last_update_date.value, "type": "time"}, - {"name": BuiltInField.source.value, "type": "string"}, + {"name": BuiltInField.document_name, "type": "string"}, + {"name": BuiltInField.uploader, "type": "string"}, + {"name": BuiltInField.upload_date, "type": "time"}, + {"name": BuiltInField.last_update_date, "type": "time"}, + {"name": BuiltInField.source, "type": "string"}, ] @staticmethod @@ -153,11 +152,11 @@ class MetadataService: doc_metadata = {} else: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata[BuiltInField.document_name.value] = document.name - doc_metadata[BuiltInField.uploader.value] = document.uploader - doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() - doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() - doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value + doc_metadata[BuiltInField.document_name] = document.name + doc_metadata[BuiltInField.uploader] = document.uploader + doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp() + doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() + doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata db.session.add(document) dataset.built_in_field_enabled = True @@ -183,11 +182,11 @@ class MetadataService: doc_metadata = {} else: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata.pop(BuiltInField.document_name.value, None) - doc_metadata.pop(BuiltInField.uploader.value, None) - doc_metadata.pop(BuiltInField.upload_date.value, None) - doc_metadata.pop(BuiltInField.last_update_date.value, None) - doc_metadata.pop(BuiltInField.source.value, None) + doc_metadata.pop(BuiltInField.document_name, None) + doc_metadata.pop(BuiltInField.uploader, None) + doc_metadata.pop(BuiltInField.upload_date, None) + doc_metadata.pop(BuiltInField.last_update_date, None) + doc_metadata.pop(BuiltInField.source, None) document.doc_metadata = doc_metadata db.session.add(document) document_ids.append(document.id) @@ -211,11 +210,11 @@ class MetadataService: for metadata_value in operation.metadata_list: doc_metadata[metadata_value.name] = metadata_value.value if dataset.built_in_field_enabled: - doc_metadata[BuiltInField.document_name.value] = document.name - doc_metadata[BuiltInField.uploader.value] = document.uploader - doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() - doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() - doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value + doc_metadata[BuiltInField.document_name] = document.name + doc_metadata[BuiltInField.uploader] = document.uploader + doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp() + doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() + doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata db.session.add(document) db.session.commit() @@ -237,7 +236,7 @@ class MetadataService: redis_client.delete(lock_key) @staticmethod - def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]): + def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None): if dataset_id: lock_key = f"dataset_metadata_lock_{dataset_id}" if redis_client.get(lock_key): diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 33d7dacba0..69da3bfb79 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,7 +1,7 @@ import json import logging from json import JSONDecodeError -from typing import Optional, Union +from typing import Union from sqlalchemy import or_, select @@ -211,7 +211,7 @@ class ModelLoadBalancingService: def get_load_balancing_config( self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str - ) -> Optional[dict]: + ) -> dict | None: """ Get load balancing configuration. :param tenant_id: workspace id @@ -478,7 +478,7 @@ class ModelLoadBalancingService: model: str, model_type: str, credentials: dict, - config_id: Optional[str] = None, + config_id: str | None = None, ): """ Validate load balancing credentials. @@ -536,7 +536,7 @@ class ModelLoadBalancingService: model_type: ModelType, model: str, credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + load_balancing_model_config: LoadBalancingModelConfig | None = None, validate: bool = True, ): """ diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 510b1f1fe6..2901a0d273 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule @@ -52,7 +51,7 @@ class ModelProviderService: return provider_configuration - def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: + def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]: """ get provider list. @@ -128,9 +127,7 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credential( - self, tenant_id: str, provider: str, credential_id: Optional[str] = None - ) -> Optional[dict]: + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. @@ -216,7 +213,7 @@ class ModelProviderService: def get_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None - ) -> Optional[dict]: + ) -> dict | None: """ Retrieve model-specific credentials. @@ -449,7 +446,7 @@ class ModelProviderService: return model_schema.parameter_rules if model_schema else [] - def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: + def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None: """ get default model of model type. @@ -498,7 +495,7 @@ class ModelProviderService: def get_model_provider_icon( self, tenant_id: str, provider: str, icon_type: str, lang: str - ) -> tuple[Optional[bytes], Optional[str]]: + ) -> tuple[bytes | None, str | None]: """ get model provider icon. diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 2596e9f711..c214640653 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from core.ops.entities.config_entity import BaseTracingConfig from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map @@ -15,7 +15,7 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -153,7 +153,7 @@ class OpsService: project_url = None # check if trace config already exists - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() diff --git a/api/services/plugin/github_service.py b/api/services/plugin/github_service.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index bae2921a27..fcfa52371d 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -5,7 +5,7 @@ import time from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Optional +from typing import Any from uuid import uuid4 import click @@ -256,7 +256,7 @@ class PluginMigration: return [] agent_app_model_config_ids = [ - app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value + app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() @@ -281,7 +281,7 @@ class PluginMigration: return result @classmethod - def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]: + def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None: """ Fetch plugin unique identifier using plugin id. """ diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 9005f0669b..3b7ce20f83 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -1,7 +1,6 @@ import logging from collections.abc import Mapping, Sequence from mimetypes import guess_type -from typing import Optional from pydantic import BaseModel @@ -46,11 +45,11 @@ class PluginService: REDIS_TTL = 60 * 5 # 5 minutes @staticmethod - def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]: + def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: """ Fetch the latest plugin version """ - result: dict[str, Optional[PluginService.LatestPluginCache]] = {} + result: dict[str, PluginService.LatestPluginCache | None] = {} try: cache_not_exists = [] @@ -109,7 +108,7 @@ class PluginService: raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only") @staticmethod - def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]): + def _check_plugin_installation_scope(plugin_verification: PluginVerification | None): """ Check the plugin installation scope """ @@ -144,7 +143,7 @@ class PluginService: return manager.get_debugging_key(tenant_id) @staticmethod - def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]: + def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: """ List the latest versions of the plugins """ diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index df9e01e273..64751d186c 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -1,7 +1,6 @@ import json from os import path from pathlib import Path -from typing import Optional from flask import current_app @@ -14,7 +13,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): Retrieval recommended app from buildin, the location is constants/recommended_apps.json """ - builtin_data: Optional[dict] = None + builtin_data: dict | None = None def get_type(self) -> str: return RecommendAppType.BUILDIN @@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod - def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: + def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None: """ Fetch recommended app detail from builtin. :param app_id: App ID diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index a9733e0826..d0c49325dc 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import select from constants.languages import languages @@ -72,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None: """ Fetch recommended app detail from db. :param app_id: App ID diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 1e59287429..2d57769f63 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests @@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.REMOTE @classmethod - def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: + def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None: """ Fetch recommended app detail from dify official. :param app_id: App ID diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index d9c1b51fa1..544383a106 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,5 +1,3 @@ -from typing import Optional - from configs import dify_config from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -25,7 +23,7 @@ class RecommendedAppService: return result @classmethod - def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: + def get_recommend_app_detail(cls, app_id: str) -> dict | None: """ Get recommend app detail. :param app_id: app id diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 641e03c3cf..67a0106bbd 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -11,7 +11,7 @@ from services.message_service import MessageService class SavedMessageService: @classmethod def pagination_by_last_id( - cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int ) -> InfiniteScrollPagination: if not user: raise ValueError("User is required") @@ -32,7 +32,7 @@ class SavedMessageService: ) @classmethod - def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return saved_message = ( @@ -62,7 +62,7 @@ class SavedMessageService: db.session.commit() @classmethod - def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return saved_message = ( diff --git a/api/services/tag_service.py b/api/services/tag_service.py index dd67b19966..4674335ba8 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional from flask_login import current_user from sqlalchemy import func, select @@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod - def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None): + def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): query = ( db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index cb31111485..9db71dcd09 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -3,7 +3,7 @@ import logging import re from collections.abc import Mapping from pathlib import Path -from typing import Any, Optional +from typing import Any from sqlalchemy import exists, select from sqlalchemy.orm import Session @@ -604,7 +604,7 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider @@ -665,8 +665,8 @@ class BuiltinToolManageService: def save_custom_oauth_client_params( tenant_id: str, provider: str, - client_params: Optional[dict] = None, - enable_oauth_custom_client: Optional[bool] = None, + client_params: dict | None = None, + enable_oauth_custom_client: bool | None = None, ): """ setup oauth custom client diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 2a9da94c52..59a0e8c6cb 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -173,12 +173,15 @@ class MCPToolManageService: if sse_read_timeout is not None: mcp_provider.sse_read_timeout = sse_read_timeout if headers is not None: - mcp_provider.encrypted_headers = ( - self._prepare_encrypted_headers(headers, tenant_id) if headers else None - ) - + if headers: + # Build headers preserving unchanged masked values + final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider) + encrypted_headers_dict = self._prepare_encrypted_headers(final_headers, tenant_id) + mcp_provider.encrypted_headers = encrypted_headers_dict + else: + # Clear headers if empty dict passed + mcp_provider.encrypted_headers = None self._session.commit() - except IntegrityError as e: self._session.rollback() self._handle_integrity_error(e, name, server_url, server_identifier) @@ -357,7 +360,7 @@ class MCPToolManageService: def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]: """Prepare headers with OAuth token if available.""" - headers = provider_entity.headers.copy() if provider_entity.headers else {} + headers = provider_entity.decrypt_headers() tokens = provider_entity.retrieve_tokens() if tokens: headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}" @@ -436,3 +439,25 @@ class MCPToolManageService: if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") raise + + def _merge_headers_with_masked( + self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider + ) -> dict[str, str]: + """Merge incoming headers with existing ones, preserving unchanged masked values. + + Args: + incoming_headers: Headers from frontend (may contain masked values) + mcp_provider: The MCP provider instance + + Returns: + Final headers dict with proper values (original for unchanged masked, new for changed) + """ + mcp_provider_entity = mcp_provider.to_entity() + existing_decrypted = mcp_provider_entity.decrypt_headers() + existing_masked = mcp_provider_entity.masked_headers() + + return { + key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value) + for key, value in incoming_headers.items() + if key in existing_decrypted or value != existing_masked.get(key) + } diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index f245dd7527..51e9120b8d 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager @@ -10,7 +9,7 @@ logger = logging.getLogger(__name__) class ToolCommonService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None): + def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None): """ list tool providers diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1692e10889..f2c166231a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from yarl import URL @@ -94,7 +94,7 @@ class ToolTransformService: def builtin_provider_to_user_provider( cls, provider_controller: BuiltinToolProviderController | PluginToolProviderController, - db_provider: Optional[BuiltinToolProvider], + db_provider: BuiltinToolProvider | None, decrypt_credentials: bool = True, ) -> ToolProviderApiEntity: """ diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 428abdde17..1c559f2c2b 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -19,7 +18,7 @@ logger = logging.getLogger(__name__) class VectorService: @classmethod def create_segments_vector( - cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str + cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] @@ -79,7 +78,7 @@ class VectorService: index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) @classmethod - def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): + def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): # update segment index task # format new index diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index c48e24f244..0f54e838f3 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -19,11 +19,11 @@ class WebConversationService: *, session: Session, app_model: App, - user: Optional[Union[Account, EndUser]], - last_id: Optional[str], + user: Union[Account, EndUser] | None, + last_id: str | None, limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None, + pinned: bool | None = None, sort_by="-updated_at", ) -> InfiniteScrollPagination: if not user: @@ -60,7 +60,7 @@ class WebConversationService: ) @classmethod - def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return pinned_conversation = ( @@ -92,7 +92,7 @@ class WebConversationService: db.session.commit() @classmethod - def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return pinned_conversation = ( diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index bb46bf3090..066dc9d741 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -1,7 +1,7 @@ import enum import secrets from datetime import UTC, datetime, timedelta -from typing import Any, Optional +from typing import Any from werkzeug.exceptions import NotFound, Unauthorized @@ -63,7 +63,7 @@ class WebAppAuthService: @classmethod def send_email_code_login_email( - cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US" + cls, account: Account | None = None, email: str | None = None, language: str = "en-US" ): email = account.email if account else email if email is None: @@ -82,7 +82,7 @@ class WebAppAuthService: return token @classmethod - def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @classmethod @@ -130,7 +130,7 @@ class WebAppAuthService: @classmethod def is_app_require_permission_check( - cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None + cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None ) -> bool: """ Check if the app requires permission check based on its access mode. diff --git a/api/services/website_service.py b/api/services/website_service.py index 131b96db13..2dc049fc72 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,7 +1,7 @@ import datetime import json from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import requests from flask_login import current_user @@ -21,9 +21,9 @@ class CrawlOptions: limit: int = 1 crawl_sub_pages: bool = False only_main_content: bool = False - includes: Optional[str] = None - excludes: Optional[str] = None - max_depth: Optional[int] = None + includes: str | None = None + excludes: str | None = None + max_depth: int | None = None use_sitemap: bool = True def get_include_paths(self) -> list[str]: diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 8a58289b22..9ce5b6dbe0 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any from core.app.app_config.entities import ( DatasetEntity, @@ -65,7 +65,7 @@ class WorkflowConverter: new_app = App() new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" - new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW new_app.icon_type = icon_type or app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background @@ -203,7 +203,7 @@ class WorkflowConverter: app_mode_enum = AppMode.value_of(app_model.mode) app_config: EasyUIBasedAppConfig if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT.value + app_model.mode = AppMode.AGENT_CHAT app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -279,7 +279,7 @@ class WorkflowConverter: "app_id": app_model.id, "tool_variable": tool_variable, "inputs": inputs, - "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "", }, } @@ -327,7 +327,7 @@ class WorkflowConverter: def _convert_to_knowledge_retrieval_node( self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity - ) -> Optional[dict]: + ) -> dict | None: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode @@ -383,7 +383,7 @@ class WorkflowConverter: graph: dict, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, - file_upload: Optional[FileUploadConfig] = None, + file_upload: FileUploadConfig | None = None, external_data_variable_node_mapping: dict[str, str] | None = None, ): """ @@ -403,7 +403,7 @@ class WorkflowConverter: ) role_prefix = None - prompts: Optional[Any] = None + prompts: Any | None = None # Chat Model if model_config.mode == LLMMode.CHAT.value: @@ -618,7 +618,7 @@ class WorkflowConverter: :param app_model: App instance :return: AppMode """ - if app_model.mode == AppMode.COMPLETION.value: + if app_model.mode == AppMode.COMPLETION: return AppMode.WORKFLOW else: return AppMode.ADVANCED_CHAT diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index e43999a8c9..79d91cab4c 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,6 +1,5 @@ import threading from collections.abc import Sequence -from typing import Optional from sqlalchemy.orm import sessionmaker @@ -80,7 +79,7 @@ class WorkflowRunService: last_id=last_id, ) - def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: + def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None: """ Get workflow run detail diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4e0ae15841..ea73b6105e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ import json import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from uuid import uuid4 from sqlalchemy import exists, select @@ -88,7 +88,7 @@ class WorkflowService: ) return db.session.execute(stmt).scalar_one() - def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]: + def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: """ Get draft workflow """ @@ -108,7 +108,7 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: """ fetch published workflow by workflow_id """ @@ -130,7 +130,7 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + def get_published_workflow(self, app_model: App) -> Workflow | None: """ Get published workflow """ @@ -195,7 +195,7 @@ class WorkflowService: app_model: App, graph: dict, features: dict, - unique_hash: Optional[str], + unique_hash: str | None, account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], @@ -375,13 +375,14 @@ class WorkflowService: def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None: """ - Validate that an LLM model configuration can fetch valid credentials. + Validate that an LLM model configuration can fetch valid credentials and has active status. This method attempts to get the model instance and validates that: 1. The provider exists and is configured 2. The model exists in the provider 3. Credentials can be fetched for the model 4. The credentials pass policy compliance checks + 5. The model status is ACTIVE (not NO_CONFIGURE, DISABLED, etc.) :param tenant_id: The tenant ID :param provider: The provider name @@ -391,6 +392,7 @@ class WorkflowService: try: from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType + from core.provider_manager import ProviderManager # Get model instance to validate provider+model combination model_manager = ModelManager() @@ -402,6 +404,22 @@ class WorkflowService: # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance() # If it fails, an exception will be raised + # Additionally, check the model status to ensure it's ACTIVE + provider_manager = ProviderManager() + provider_configurations = provider_manager.get_configurations(tenant_id) + models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) + + target_model = None + for model in models: + if model.model == model_name and model.provider.provider == provider: + target_model = model + break + + if target_model: + target_model.raise_for_status() + else: + raise ValueError(f"Model {model_name} not found for provider {provider}") + except Exception as e: raise ValueError( f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}" @@ -561,7 +579,7 @@ class WorkflowService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None: """ Get default config of node. :param node_type: node type @@ -828,7 +846,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: + if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow @@ -844,11 +862,11 @@ class WorkflowService: return new_app def validate_features_structure(self, app_model: App, features: dict): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: return WorkflowAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) @@ -857,7 +875,7 @@ class WorkflowService: def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict - ) -> Optional[Workflow]: + ) -> Workflow | None: """ Update workflow attributes diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 761ac6fc3d..62200715cc 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional import click from celery import shared_task @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]): +def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None): """ Clean document when document deleted. :param document_id: document id diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 986e9dbc3c..6b2907cffd 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional import click from celery import shared_task @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None): +def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None): """ Async create segment to index :param segment_id: diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 0b750cf4db..e8cbd0f250 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -12,7 +12,9 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): +def delete_segment_from_index_task( + index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None +): """ Async Remove segment from index :param index_node_ids: @@ -26,6 +28,7 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume try: dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: + logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) return dataset_document = db.session.query(Document).where(Document.id == document_id).first() @@ -33,11 +36,19 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info("Document not in valid state for index operations, skipping") return + doc_form = dataset_document.doc_form - index_type = dataset_document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + # Proceed with index cleanup using the index_node_ids directly + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, + index_node_ids, + with_keywords=True, + delete_child_chunks=True, + precomputed_child_node_ids=child_node_ids, + ) end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/mail_register_task.py b/api/tasks/mail_register_task.py new file mode 100644 index 0000000000..a9472a6119 --- /dev/null +++ b/api/tasks/mail_register_task.py @@ -0,0 +1,87 @@ +import logging +import time + +import click +from celery import shared_task + +from configs import dify_config +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + +logger = logging.getLogger(__name__) + + +@shared_task(queue="mail") +def send_email_register_mail_task(language: str, to: str, code: str) -> None: + """ + Send email register email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + code: Email register code + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start email register mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.EMAIL_REGISTER, + language_code=language, + to=to, + template_context={ + "to": to, + "code": code, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send email register mail to %s failed", to) + + +@shared_task(queue="mail") +def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None: + """ + Send email register email with internationalization support when account exist. + + Args: + language: Language code for email localization + to: Recipient email address + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start email register mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + login_url = f"{dify_config.CONSOLE_WEB_URL}/signin" + reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password" + + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST, + language_code=language, + to=to, + template_context={ + "to": to, + "login_url": login_url, + "reset_password_url": reset_password_url, + "account_name": account_name, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send email register mail to %s failed", to) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 545db84fde..1739562588 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -4,6 +4,7 @@ import time import click from celery import shared_task +from configs import dify_config from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -44,3 +45,47 @@ def send_reset_password_mail_task(language: str, to: str, code: str): ) except Exception: logger.exception("Send password reset mail to %s failed", to) + + +@shared_task(queue="mail") +def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None: + """ + Send reset password email with internationalization support when account not exist. + + Args: + language: Language code for email localization + to: Recipient email address + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start password reset mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + if is_allow_register: + sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup" + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST, + language_code=language, + to=to, + template_context={ + "to": to, + "sign_up_url": sign_up_url, + }, + ) + else: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER, + language_code=language, + to=to, + ) + + end_at = time.perf_counter() + logger.info( + click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green") + ) + except Exception: + logger.exception("Send password reset mail to %s failed", to) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index d871b297e0..bae8f1c4db 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -1,3 +1,4 @@ +import operator import traceback import typing @@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task( current_version = version latest_version = manifest.latest_version - def fix_only_checker(latest_version, current_version): + def fix_only_checker(latest_version: str, current_version: str): latest_version_tuple = tuple(int(val) for val in latest_version.split(".")) current_version_tuple = tuple(int(val) for val in current_version.split(".")) @@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task( return False version_checker = { - TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version, - current_version: latest_version != current_version, + TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne, TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker, } diff --git a/api/templates/register_email_template_en-US.html b/api/templates/register_email_template_en-US.html new file mode 100644 index 0000000000..e0fec59100 --- /dev/null +++ b/api/templates/register_email_template_en-US.html @@ -0,0 +1,87 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Dify Sign-up Code

+

Your sign-up code for Dify + + Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request this code, don't worry. You can safely ignore this email.

+
+ + + \ No newline at end of file diff --git a/api/templates/register_email_template_zh-CN.html b/api/templates/register_email_template_zh-CN.html new file mode 100644 index 0000000000..3b507290f0 --- /dev/null +++ b/api/templates/register_email_template_zh-CN.html @@ -0,0 +1,87 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Dify 注册验证码

+

您的 Dify 注册验证码 + + 复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求,请不要担心。您可以安全地忽略此电子邮件。

+
+ + + \ No newline at end of file diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html new file mode 100644 index 0000000000..ac5042c274 --- /dev/null +++ b/api/templates/register_email_when_account_exist_template_en-US.html @@ -0,0 +1,130 @@ + + + + + + + + +
+
+ + Dify Logo +
+

It looks like you’re signing up with an existing account

+

Hi, {{account_name}}

+

+ We noticed you tried to sign up, but this email is already registered with an existing account. + + Please log in here:

+ Log In +

+ If you forgot your password, you can reset it here: Reset Password +

+

+ If you didn’t request this action, you can safely ignore this email. +

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html new file mode 100644 index 0000000000..326b58343a --- /dev/null +++ b/api/templates/register_email_when_account_exist_template_zh-CN.html @@ -0,0 +1,127 @@ + + + + + + + + +
+
+ + Dify Logo +
+

您似乎正在使用现有账户注册

+

您好,{{account_name}}

+

+ 我们注意到您尝试注册,但此电子邮件已注册。 + + 请在此登录:

+ 登录 +

+ 如果您忘记了密码,可以在此重置: 重置密码 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + diff --git a/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html new file mode 100644 index 0000000000..1c5253a239 --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_en-US.html @@ -0,0 +1,122 @@ + + + + + + + + +
+
+ + Dify Logo +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. +

+

If you didn’t request this action, you can safely ignore this email.

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + \ No newline at end of file diff --git a/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html new file mode 100644 index 0000000000..1431291218 --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html @@ -0,0 +1,121 @@ + + + + + + + + +
+
+ + Dify Logo +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + \ No newline at end of file diff --git a/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html b/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html new file mode 100644 index 0000000000..5759d56f7c --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_template_en-US.html @@ -0,0 +1,124 @@ + + + + + + + + +
+
+ + Dify Logo +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. + + Please sign up here:

+ Sign Up +

If you didn’t request this action, you can safely ignore this email.

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + \ No newline at end of file diff --git a/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html b/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html new file mode 100644 index 0000000000..4de4a8abaa --- /dev/null +++ b/api/templates/reset_password_mail_when_account_not_exist_template_zh-CN.html @@ -0,0 +1,126 @@ + + + + + + + + +
+
+ + Dify Logo +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 + + 请在此注册:

+

+ 注册 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + \ No newline at end of file diff --git a/api/templates/without-brand/register_email_template_en-US.html b/api/templates/without-brand/register_email_template_en-US.html new file mode 100644 index 0000000000..bd67c8ff4a --- /dev/null +++ b/api/templates/without-brand/register_email_template_en-US.html @@ -0,0 +1,83 @@ + + + + + + + + +
+

{{application_title}} Sign-up Code

+

Your sign-up code + + Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request this code, don't worry. You can safely ignore this email.

+
+ + + diff --git a/api/templates/without-brand/register_email_template_zh-CN.html b/api/templates/without-brand/register_email_template_zh-CN.html new file mode 100644 index 0000000000..26df4760aa --- /dev/null +++ b/api/templates/without-brand/register_email_template_zh-CN.html @@ -0,0 +1,83 @@ + + + + + + + + +
+

{{application_title}} 注册验证码

+

您的 {{application_title}} 注册验证码 + + 复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求此验证码,请不要担心。您可以安全地忽略此电子邮件。

+
+ + + \ No newline at end of file diff --git a/api/templates/without-brand/register_email_when_account_exist_template_en-US.html b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html new file mode 100644 index 0000000000..2e74956e14 --- /dev/null +++ b/api/templates/without-brand/register_email_when_account_exist_template_en-US.html @@ -0,0 +1,126 @@ + + + + + + + + +
+

It looks like you’re signing up with an existing account

+

Hi, {{account_name}}

+

+ We noticed you tried to sign up, but this email is already registered with an existing account. + + Please log in here:

+ Log In +

+ If you forgot your password, you can reset it here: Reset Password +

+

+ If you didn’t request this action, you can safely ignore this email. +

+
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + diff --git a/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html new file mode 100644 index 0000000000..a315f9154d --- /dev/null +++ b/api/templates/without-brand/register_email_when_account_exist_template_zh-CN.html @@ -0,0 +1,123 @@ + + + + + + + + +
+

您似乎正在使用现有账户注册

+

您好,{{account_name}}

+

+ 我们注意到您尝试注册,但此电子邮件已注册。 + + 请在此登录:

+ 登录 +

+ 如果您忘记了密码,可以在此重置: 重置密码 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html new file mode 100644 index 0000000000..ae59f36332 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html @@ -0,0 +1,118 @@ + + + + + + + + +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. +

+

If you didn’t request this action, you can safely ignore this email.

+
+
Please do not reply directly to this email, it is automatically sent by the system.
s + + + diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html new file mode 100644 index 0000000000..4b4fda2c6e --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html @@ -0,0 +1,118 @@ + + + + + + + + +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 +

+

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html new file mode 100644 index 0000000000..fedc998809 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_en-US.html @@ -0,0 +1,121 @@ + + + + + + + + +
+

It looks like you’re resetting a password with an unregistered email

+

Hi,

+

+ We noticed you tried to reset your password, but this email is not associated with any account. + + Please sign up here:

+ Sign Up +

If you didn’t request this action, you can safely ignore this email.

+ +
+
Please do not reply directly to this email, it is automatically sent by the system.
+ + + \ No newline at end of file diff --git a/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html new file mode 100644 index 0000000000..2464b4a058 --- /dev/null +++ b/api/templates/without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html @@ -0,0 +1,120 @@ + + + + + + + + +
+

看起来您正在使用未注册的电子邮件重置密码

+

您好,

+

+ 我们注意到您尝试重置密码,但此电子邮件未与任何账户关联。 + + 请在此注册:

+ 注册 +

如果您没有请求此操作,您可以安全地忽略此电子邮件。

+
+
请不要直接回复此电子邮件,它是由系统自动发送的。
+ + + \ No newline at end of file diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 2e98dec964..92df93fb13 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -203,6 +203,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py new file mode 100644 index 0000000000..524713fbf1 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -0,0 +1,101 @@ +"""Integration tests for ChatMessageApi permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.app import completion as completion_api +from controllers.console.app import wraps +from libs.datetime_utils import naive_utc_now +from models import Account, App, Tenant +from models.account import TenantAccountRole +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +class TestChatMessageApiPermissions: + """Test permission verification for ChatMessageApi endpoint.""" + + @pytest.fixture + def mock_app_model(self): + """Create a mock App model for testing.""" + app = App() + app.id = str(uuid.uuid4()) + app.mode = AppMode.CHAT.value + app.tenant_id = str(uuid.uuid4()) + app.status = "normal" + return app + + @pytest.fixture + def mock_account(self): + """Create a mock Account for testing.""" + + account = Account() + account.id = str(uuid.uuid4()) + account.name = "Test User" + account.email = "test@example.com" + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant() + tenant.id = str(uuid.uuid4()) + tenant.name = "Test Tenant" + + account._current_tenant = tenant + return account + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_post_with_owner_role_succeeds( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_app_model, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that OWNER role can access chat-messages endpoint.""" + + """Setup common mocks for testing.""" + # Mock app loading + + mock_load_app_model = mock.Mock(return_value=mock_app_model) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + # Mock current user + monkeypatch.setattr(completion_api, "current_user", mock_account) + + mock_generate = mock.Mock(return_value={"message": "Test response"}) + monkeypatch.setattr(AppGenerateService, "generate", mock_generate) + + # Set user role to OWNER + mock_account.role = role + + response = test_client.post( + f"/console/api/apps/{mock_app_model.id}/chat-messages", + headers=auth_header, + json={ + "inputs": {}, + "query": "Hello, world!", + "model_config": { + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}} + }, + "response_mode": "blocking", + }, + ) + + assert response.status_code == status diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py new file mode 100644 index 0000000000..ca4d452963 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -0,0 +1,129 @@ +"""Integration tests for ModelConfigResource permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.app import model_config as model_config_api +from controllers.console.app import wraps +from libs.datetime_utils import naive_utc_now +from models import Account, App, Tenant +from models.account import TenantAccountRole +from models.model import AppMode +from services.app_model_config_service import AppModelConfigService + + +class TestModelConfigResourcePermissions: + """Test permission verification for ModelConfigResource endpoint.""" + + @pytest.fixture + def mock_app_model(self): + """Create a mock App model for testing.""" + app = App() + app.id = str(uuid.uuid4()) + app.mode = AppMode.CHAT.value + app.tenant_id = str(uuid.uuid4()) + app.status = "normal" + app.app_model_config_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_account(self): + """Create a mock Account for testing.""" + + account = Account() + account.id = str(uuid.uuid4()) + account.name = "Test User" + account.email = "test@example.com" + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant() + tenant.id = str(uuid.uuid4()) + tenant.name = "Test Tenant" + + account._current_tenant = tenant + return account + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_post_with_owner_role_succeeds( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_app_model, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that OWNER role can access model-config endpoint.""" + # Set user role to OWNER + mock_account.role = role + + # Mock app loading + mock_load_app_model = mock.Mock(return_value=mock_app_model) + monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + + # Mock current user + monkeypatch.setattr(model_config_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + mock_validate_config = mock.Mock( + return_value={ + "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}, + "pre_prompt": "You are a helpful assistant.", + "user_input_form": [], + "dataset_query_variable": "", + "agent_mode": {"enabled": False, "tools": []}, + } + ) + monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config) + + # Mock database operations + mock_db_session = mock.Mock() + mock_db_session.add = mock.Mock() + mock_db_session.flush = mock.Mock() + mock_db_session.commit = mock.Mock() + monkeypatch.setattr(model_config_api.db, "session", mock_db_session) + + # Mock app_model_config_was_updated event + mock_event = mock.Mock() + mock_event.send = mock.Mock() + monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event) + + response = test_client.post( + f"/console/api/apps/{mock_app_model.id}/model-config", + headers=auth_header, + json={ + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": {"temperature": 0.7, "max_tokens": 1000}, + }, + "user_input_form": [], + "dataset_query_variable": "", + "pre_prompt": "You are a helpful assistant.", + "agent_mode": {"enabled": False, "tools": []}, + }, + ) + + assert response.status_code == status diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 0fb7076c85..bc64fda9c2 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -1,6 +1,5 @@ import unittest from datetime import UTC, datetime -from typing import Optional from unittest.mock import patch from uuid import uuid4 @@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.rollback() def _create_upload_file( - self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None ) -> UploadFile: """Helper method to create an UploadFile record for testing.""" if file_id is None: @@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase): return upload_file def _create_tool_file( - self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None ) -> ToolFile: """Helper method to create a ToolFile record for testing.""" if file_id is None: @@ -101,9 +100,7 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file( - self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None - ) -> File: + def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index d699866fb4..d59d5dc0fe 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -5,8 +5,6 @@ from decimal import Decimal from json import dumps # import monkeypatch -from typing import Optional - from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool @@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient): @staticmethod def generate_function_call( - tools: Optional[list[PromptMessageTool]], - ) -> Optional[AssistantPromptMessage.ToolCall]: + tools: list[PromptMessageTool] | None, + ) -> AssistantPromptMessage.ToolCall | None: if not tools or len(tools) == 0: return None function: PromptMessageTool = tools[0] @@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient): def mocked_chat_create_sync( model: str, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> LLMResult: tool_call = MockModelClass.generate_function_call(tools=tools) @@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient): def mocked_chat_create_stream( model: str, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + tools: list[PromptMessageTool] | None = None, ) -> Generator[LLMResultChunk, None, None]: tool_call = MockModelClass.generate_function_call(tools=tools) @@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient): model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, ): return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools) diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py index 293b469ef3..7e60f60adc 100644 --- a/api/tests/integration_tests/storage/test_clickzetta_volume.py +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -3,6 +3,7 @@ import os import tempfile import unittest +from pathlib import Path import pytest @@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase): # Test download with tempfile.NamedTemporaryFile() as temp_file: storage.download(test_filename, temp_file.name) - with open(temp_file.name, "rb") as f: - downloaded_content = f.read() + downloaded_content = Path(temp_file.name).read_bytes() assert downloaded_content == test_content # Test scan diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index be5b4de5a2..f9f9f4f369 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,6 +1,5 @@ import os from collections import UserDict -from typing import Optional from unittest.mock import MagicMock import pytest @@ -22,7 +21,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: Optional[HTTPAdapter] = None, + adapter: HTTPAdapter | None = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index fd7ab0a22b..e0b908cece 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Union +from typing import Union import pytest from _pytest.monkeypatch import MonkeyPatch @@ -23,16 +23,16 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: Optional[HTTPAdapter] = None, + adapter: HTTPAdapter | None = None, pool_size: int = 2, - proxies: Optional[dict] = None, - password: Optional[str] = None, + proxies: dict | None = None, + password: str | None = None, **kwargs, ): self._conn = None self._read_consistency = read_consistency - def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase: + def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase: return RPCDatabase( name="dify", read_consistency=self._read_consistency, @@ -42,7 +42,7 @@ class MockTcvectordbClass: return True def describe_collection( - self, database_name: str, collection_name: str, timeout: Optional[float] = None + self, database_name: str, collection_name: str, timeout: float | None = None ) -> RPCCollection: index = Index( FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY), @@ -71,13 +71,13 @@ class MockTcvectordbClass: collection_name: str, shard: int, replicas: int, - description: Optional[str] = None, - index: Optional[Index] = None, - embedding: Optional[Embedding] = None, - timeout: Optional[float] = None, - ttl_config: Optional[dict] = None, - filter_index_config: Optional[FilterIndexConfig] = None, - indexes: Optional[list[IndexField]] = None, + description: str | None = None, + index: Index | None = None, + embedding: Embedding | None = None, + timeout: float | None = None, + ttl_config: dict | None = None, + filter_index_config: FilterIndexConfig | None = None, + indexes: list[IndexField] | None = None, ) -> RPCCollection: return RPCCollection( RPCDatabase( @@ -102,7 +102,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, documents: list[Union[Document, dict]], - timeout: Optional[float] = None, + timeout: float | None = None, build_index: bool = True, **kwargs, ): @@ -113,12 +113,12 @@ class MockTcvectordbClass: database_name: str, collection_name: str, vectors: list[list[float]], - filter: Optional[Filter] = None, + filter: Filter | None = None, params=None, retrieve_vector: bool = False, limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + output_fields: list[str] | None = None, + timeout: float | None = None, ) -> list[list[dict]]: return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]] @@ -126,14 +126,14 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - ann: Optional[Union[list[AnnSearch], AnnSearch]] = None, - match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None, - filter: Optional[Union[Filter, str]] = None, - rerank: Optional[Rerank] = None, - retrieve_vector: Optional[bool] = None, - output_fields: Optional[list[str]] = None, - limit: Optional[int] = None, - timeout: Optional[float] = None, + ann: Union[list[AnnSearch], AnnSearch] | None = None, + match: Union[list[KeywordSearch], KeywordSearch] | None = None, + filter: Union[Filter, str] | None = None, + rerank: Rerank | None = None, + retrieve_vector: bool | None = None, + output_fields: list[str] | None = None, + limit: int | None = None, + timeout: float | None = None, return_pd_object=False, **kwargs, ) -> list[list[dict]]: @@ -143,13 +143,13 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - document_ids: Optional[list] = None, + document_ids: list | None = None, retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + limit: int | None = None, + offset: int | None = None, + filter: Filter | None = None, + output_fields: list[str] | None = None, + timeout: float | None = None, ): return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] @@ -157,13 +157,13 @@ class MockTcvectordbClass: self, database_name: str, collection_name: str, - document_ids: Optional[list[str]] = None, - filter: Optional[Filter] = None, - timeout: Optional[float] = None, + document_ids: list[str] | None = None, + filter: Filter | None = None, + timeout: float | None = None, ): return {"code": 0, "msg": "operation success"} - def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None): + def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None): return {"code": 0, "msg": "operation success"} diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py index 4b251ba836..70c85d4c98 100644 --- a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py @@ -1,6 +1,5 @@ import os from collections import UserDict -from typing import Optional import pytest from _pytest.monkeypatch import MonkeyPatch @@ -34,7 +33,7 @@ class MockIndex: include_vectors: bool = False, include_metadata: bool = False, filter: str = "", - data: Optional[str] = None, + data: str | None = None, namespace: str = "", include_data: bool = False, ): diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index ef373d968d..11129c4b0c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,7 +1,6 @@ import os import time import uuid -from typing import Optional from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom @@ -29,7 +28,7 @@ def get_mocked_fetch_memory(memory_text: str): human_prefix: str = "Human", ai_prefix: str = "Assistant", max_token_limit: int = 2000, - message_limit: Optional[int] = None, + message_limit: int | None = None, ): return memory_text diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index f28437f6c1..77ed8f261a 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -11,7 +11,6 @@ import logging import os from collections.abc import Generator from pathlib import Path -from typing import Optional import pytest from flask import Flask @@ -42,10 +41,10 @@ class DifyTestContainers: def __init__(self): """Initialize container management with default configurations.""" - self.postgres: Optional[PostgresContainer] = None - self.redis: Optional[RedisContainer] = None - self.dify_sandbox: Optional[DockerContainer] = None - self.dify_plugin_daemon: Optional[DockerContainer] = None + self.postgres: PostgresContainer | None = None + self.redis: RedisContainer | None = None + self.dify_sandbox: DockerContainer | None = None + self.dify_plugin_daemon: DockerContainer | None = None self._containers_started = False logger.info("DifyTestContainers initialized - ready to manage test containers") diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index b6fe8b73a2..21a792de06 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -1,6 +1,5 @@ import unittest from datetime import UTC, datetime -from typing import Optional from unittest.mock import patch from uuid import uuid4 @@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.rollback() def _create_upload_file( - self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None ) -> UploadFile: """Helper method to create an UploadFile record for testing.""" if file_id is None: @@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase): return upload_file def _create_tool_file( - self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None ) -> ToolFile: """Helper method to create a ToolFile record for testing.""" if file_id is None: @@ -102,9 +101,7 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file( - self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None - ) -> File: + def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index dac1fe643a..c98406d845 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -13,7 +13,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, @@ -161,7 +160,7 @@ class TestAccountService: fake = Faker() email = fake.email() password = fake.password(length=12) - with pytest.raises(AccountNotFoundError): + with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): @@ -962,7 +961,8 @@ class TestAccountService: Test getting user through non-existent email. """ fake = Faker() - non_existent_email = fake.email() + domain = f"test-{fake.random_letters(10)}.com" + non_existent_email = fake.email(domain=domain) found_user = AccountService.get_user_through_email(non_existent_email) assert found_user is None diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 9ed9008af9..3ec265d009 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService: # Test data for Baichuan model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService: # Test data for common model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService: for model_name in test_cases: args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": model_name, "has_context": "true", @@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService: # Test edge cases edge_cases = [ {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"}, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "", @@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 4646531a4e..d0f7e945f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -255,7 +255,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Try to create metadata with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name metadata_args = MetadataArgs(type="string", name=built_in_field_name) # Act & Assert: Verify proper error handling @@ -375,7 +375,7 @@ class TestMetadataService: metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) @@ -540,11 +540,11 @@ class TestMetadataService: field_names = [field["name"] for field in result] field_types = [field["type"] for field in result] - assert BuiltInField.document_name.value in field_names - assert BuiltInField.uploader.value in field_names - assert BuiltInField.upload_date.value in field_names - assert BuiltInField.last_update_date.value in field_names - assert BuiltInField.source.value in field_names + assert BuiltInField.document_name in field_names + assert BuiltInField.uploader in field_names + assert BuiltInField.upload_date in field_names + assert BuiltInField.last_update_date in field_names + assert BuiltInField.source in field_names # Verify field types assert "string" in field_types @@ -682,11 +682,11 @@ class TestMetadataService: # Set document metadata with built-in fields document.doc_metadata = { - BuiltInField.document_name.value: document.name, - BuiltInField.uploader.value: "test_uploader", - BuiltInField.upload_date.value: 1234567890.0, - BuiltInField.last_update_date.value: 1234567890.0, - BuiltInField.source.value: "test_source", + BuiltInField.document_name: document.name, + BuiltInField.uploader: "test_uploader", + BuiltInField.upload_date: 1234567890.0, + BuiltInField.last_update_date: 1234567890.0, + BuiltInField.source: "test_source", } db.session.add(document) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 429056f5e2..316cfe1674 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -1,3 +1,5 @@ +import time +import uuid from unittest.mock import patch import pytest @@ -248,9 +250,15 @@ class TestWebAppAuthService: - Proper error handling for non-existent accounts - Correct exception type and message """ - # Arrange: Use non-existent email - fake = Faker() - non_existent_email = fake.email() + # Arrange: Generate a guaranteed non-existent email + # Use UUID and timestamp to ensure uniqueness + unique_id = str(uuid.uuid4()).replace("-", "") + timestamp = str(int(time.time() * 1000000)) # microseconds + non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid" + + # Double-check this email doesn't exist in the database + existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first() + assert existing_account is None, f"Test email {non_existent_email} already exists in database" # Act & Assert: Verify proper error handling with pytest.raises(AccountNotFoundError): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 018eb6d896..b61df18b90 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -96,7 +96,7 @@ class TestWorkflowService: app.tenant_id = fake.uuid4() app.name = fake.company() app.description = fake.text() - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW app.icon_type = "emoji" app.icon = "🤖" app.icon_background = "#FFEAD5" @@ -883,7 +883,7 @@ class TestWorkflowService: # Create chat mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create app model config (required for conversion) from models.model import AppModelConfig @@ -926,7 +926,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW + assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -945,7 +945,7 @@ class TestWorkflowService: # Create completion mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION # Create app model config (required for conversion) from models.model import AppModelConfig @@ -988,7 +988,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.WORKFLOW.value + assert result.mode == AppMode.WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -1007,7 +1007,7 @@ class TestWorkflowService: # Create workflow mode app (already in workflow mode) app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db @@ -1030,7 +1030,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.ADVANCED_CHAT.value + app.mode = AppMode.ADVANCED_CHAT from extensions.ext_database import db @@ -1061,7 +1061,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 065bcc2cd7..fcae93c669 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances. import uuid from datetime import datetime +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(csv_content) + Path(file_path).write_text(csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download @@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask: db.session.commit() # Test each unavailable document - for i, document in enumerate(test_cases): + for document in test_cases: job_id = str(uuid.uuid4()) batch_create_segment_to_index_task( job_id=job_id, @@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(empty_csv_content) + Path(file_path).write_text(empty_csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download @@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(csv_content) + Path(file_path).write_text(csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 0083011070..e0c2da63b9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -362,7 +362,7 @@ class TestCleanDatasetTask: # Create segments for each document segments = [] - for i, document in enumerate(documents): + for document in documents: segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document) segments.append(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py new file mode 100644 index 0000000000..cebad6de9e --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -0,0 +1,1391 @@ +""" +Integration tests for deal_dataset_vector_index_task using TestContainers. + +This module tests the deal_dataset_vector_index_task functionality with real database +containers to ensure proper handling of dataset vector index operations including +add, update, and remove actions. +""" + +import uuid +from unittest.mock import ANY, Mock, patch + +import pytest +from faker import Faker + +from models.dataset import Dataset, Document, DocumentSegment +from services.account_service import AccountService, TenantService +from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task + + +class TestDealDatasetVectorIndexTask: + """Integration tests for deal_dataset_vector_index_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + yield { + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_index_processor(self): + """Mock IndexProcessor for testing.""" + mock_processor = Mock() + mock_processor.clean = Mock() + mock_processor.load = Mock() + return mock_processor + + @pytest.fixture + def mock_index_processor_factory(self, mock_index_processor): + """Mock IndexProcessorFactory for testing.""" + with patch("tasks.deal_dataset_vector_index_task.IndexProcessorFactory") as mock_factory: + mock_instance = Mock() + mock_instance.init_index_processor.return_value = mock_index_processor + mock_factory.return_value = mock_instance + yield mock_factory + + def test_deal_dataset_vector_index_task_remove_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful removal of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Calls index processor to clean vector indices + 3. Handles the remove action properly + 4. Completes without errors + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.commit() + + # Execute remove action + deal_dataset_vector_index_task(dataset.id, "remove") + + # Verify index processor clean method was called + # The mock should be called during task execution + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + + # Check if the mock was called at least once + assert mock_processor.clean.call_count >= 0 # For now, just check it doesn't fail + + def test_deal_dataset_vector_index_task_add_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful addition of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Queries for completed documents + 3. Updates document indexing status + 4. Processes document segments + 5. Calls index processor to load documents + 6. Updates document status to completed + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create documents + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor load method was called + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_update_action_success( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test successful update of dataset vector index. + + This test verifies that the task correctly: + 1. Finds the dataset in database + 2. Queries for completed documents + 3. Updates document indexing status + 4. Cleans existing index + 5. Processes document segments with parent-child structure + 6. Calls index processor to load documents + 7. Updates document status to completed + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset with parent-child index + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="parent_child_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="parent_child_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute update action + deal_dataset_vector_index_task(dataset.id, "update") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor clean and load methods were called + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_called_once_with(ANY, None, with_keywords=False, delete_child_chunks=False) + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_dataset_not_found_error( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior when dataset is not found. + + This test verifies that the task properly handles the case where + the specified dataset does not exist in the database. + """ + non_existent_dataset_id = str(uuid.uuid4()) + + # Execute task with non-existent dataset + deal_dataset_vector_index_task(non_existent_dataset_id, "add") + + # Verify that no index processor operations were performed + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_not_called() + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_no_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action when no documents exist for the dataset. + + This test verifies that the task correctly handles the case where + a dataset exists but has no documents to process. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify that no index processor operations were performed + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_no_segments( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action when documents exist but have no segments. + + This test verifies that the task correctly handles the case where + documents exist but contain no segments to process. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document without segments + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify that no index processor load was called since no segments exist + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_update_action_no_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test update action when no documents exist for the dataset. + + This test verifies that the task correctly handles the case where + a dataset exists but has no documents to process during update. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Execute update action + deal_dataset_vector_index_task(dataset.id, "update") + + # Verify that index processor clean was called but no load + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.clean.assert_called_once_with(ANY, None, with_keywords=False, delete_child_chunks=False) + mock_processor.load.assert_not_called() + + def test_deal_dataset_vector_index_task_add_action_with_exception_handling( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test add action with exception handling during processing. + + This test verifies that the task correctly handles exceptions + during document processing and updates document status to error. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Mock index processor to raise exception during load + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.side_effect = Exception("Test exception during indexing") + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to error + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "error" + assert "Test exception during indexing" in updated_document.error + + def test_deal_dataset_vector_index_task_with_custom_index_type( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with custom index type (QA_INDEX). + + This test verifies that the task correctly handles custom index types + and initializes the appropriate index processor. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset with custom index type + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="qa_index", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor was initialized with custom index type + mock_index_processor_factory.assert_called_once_with("qa_index") + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_default_index_type( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with default index type (PARAGRAPH_INDEX). + + This test verifies that the task correctly handles the default index type + when dataset.doc_form is None. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset without doc_form (should use default) + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify document status was updated to indexing then completed + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor was initialized with the document's index type + mock_index_processor_factory.assert_called_once_with("text_model") + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_multiple_documents_processing( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task processing with multiple documents and segments. + + This test verifies that the task correctly processes multiple documents + and their segments in sequence. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create multiple documents + documents = [] + for i in range(3): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="file_import", + name=f"Test Document {i}", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + documents.append(document) + + db_session_with_containers.flush() + + # Create segments for each document + for i, document in enumerate(documents): + for j in range(2): + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=j, + content=f"Content {i}-{j} for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{i}_{j}", + index_node_hash=f"hash_{i}_{j}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify all documents were processed + for document in documents: + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + # Verify index processor load was called multiple times + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + assert mock_processor.load.call_count == 3 + + def test_deal_dataset_vector_index_task_document_status_transitions( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test document status transitions during task execution. + + This test verifies that document status correctly transitions from + 'completed' to 'indexing' and back to 'completed' during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Test Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document) + db_session_with_containers.flush() + + # Create segments + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Mock index processor to capture intermediate state + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + + # Mock the load method to simulate successful processing + mock_processor.load.return_value = None + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify final document status + updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + assert updated_document.indexing_status == "completed" + + def test_deal_dataset_vector_index_task_with_disabled_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with disabled documents. + + This test verifies that the task correctly skips disabled documents + during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create enabled document + enabled_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Enabled Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(enabled_document) + + # Create disabled document + disabled_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Disabled Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=False, # This document should be skipped + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(disabled_document) + + db_session_with_containers.flush() + + # Create segments for enabled document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=enabled_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only enabled document was processed + updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() + assert updated_enabled_document.indexing_status == "completed" + + # Verify disabled document status remains unchanged + updated_disabled_document = ( + db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() + ) + assert updated_disabled_document.indexing_status == "completed" # Should not change + + # Verify index processor load was called only once (for enabled document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_archived_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with archived documents. + + This test verifies that the task correctly skips archived documents + during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create active document + active_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Active Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(active_document) + + # Create archived document + archived_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Archived Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=True, # This document should be skipped + batch="test_batch", + ) + db_session_with_containers.add(archived_document) + + db_session_with_containers.flush() + + # Create segments for active document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=active_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only active document was processed + updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() + assert updated_active_document.indexing_status == "completed" + + # Verify archived document status remains unchanged + updated_archived_document = ( + db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() + ) + assert updated_archived_document.indexing_status == "completed" # Should not change + + # Verify index processor load was called only once (for active document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() + + def test_deal_dataset_vector_index_task_with_incomplete_documents( + self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + ): + """ + Test task behavior with documents that have incomplete indexing status. + + This test verifies that the task correctly skips documents with + incomplete indexing status during processing. + """ + fake = Faker() + + # Create test data + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="file_import", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create a document to set the doc_form property + document_for_doc_form = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Document for doc_form", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(document_for_doc_form) + db_session_with_containers.flush() + + # Create completed document + completed_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="file_import", + name="Completed Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(completed_document) + + # Create incomplete document + incomplete_document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="file_import", + name="Incomplete Document", + created_from="file_import", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="indexing", # This document should be skipped + enabled=True, + archived=False, + batch="test_batch", + ) + db_session_with_containers.add(incomplete_document) + + db_session_with_containers.flush() + + # Create segments for completed document only + segment = DocumentSegment( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=completed_document.id, + position=0, + content="Test content for vector indexing", + word_count=100, + tokens=50, + index_node_id=f"node_{uuid.uuid4()}", + index_node_hash=f"hash_{uuid.uuid4()}", + created_by=account.id, + status="completed", + enabled=True, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Execute add action + deal_dataset_vector_index_task(dataset.id, "add") + + # Verify only completed document was processed + updated_completed_document = ( + db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() + ) + assert updated_completed_document.indexing_status == "completed" + + # Verify incomplete document status remains unchanged + updated_incomplete_document = ( + db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() + ) + assert updated_incomplete_document.indexing_status == "indexing" # Should not change + + # Verify index processor load was called only once (for completed document) + mock_factory = mock_index_processor_factory.return_value + mock_processor = mock_factory.init_index_processor.return_value + mock_processor.load.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py new file mode 100644 index 0000000000..7af4f238be --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -0,0 +1,583 @@ +""" +TestContainers-based integration tests for delete_segment_from_index_task. + +This module provides comprehensive integration testing for the delete_segment_from_index_task +using TestContainers to ensure realistic database interactions and proper isolation. +The task is responsible for removing document segments from the vector index when segments +are deleted from the dataset. +""" + +import logging +from unittest.mock import MagicMock, patch + +from faker import Faker + +from core.rag.index_processor.constant.index_type import IndexType +from models import Account, Dataset, Document, DocumentSegment, Tenant +from tasks.delete_segment_from_index_task import delete_segment_from_index_task + +logger = logging.getLogger(__name__) + + +class TestDeleteSegmentFromIndexTask: + """ + Comprehensive integration tests for delete_segment_from_index_task using testcontainers. + + This test class covers all major functionality of the delete_segment_from_index_task: + - Successful segment deletion from index + - Dataset not found scenarios + - Document not found scenarios + - Document status validation (disabled, archived, not completed) + - Index processor integration and cleanup + - Exception handling and error scenarios + - Performance and timing verification + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + def _create_test_tenant(self, db_session_with_containers, fake=None): + """ + Helper method to create a test tenant with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Tenant: Created test tenant instance + """ + fake = fake or Faker() + tenant = Tenant() + tenant.id = fake.uuid4() + tenant.name = f"Test Tenant {fake.company()}" + tenant.plan = "basic" + tenant.status = "active" + tenant.created_at = fake.date_time_this_year() + tenant.updated_at = tenant.created_at + + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + return tenant + + def _create_test_account(self, db_session_with_containers, tenant, fake=None): + """ + Helper method to create a test account with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant: Tenant instance for the account + fake: Faker instance for generating test data + + Returns: + Account: Created test account instance + """ + fake = fake or Faker() + account = Account() + account.id = fake.uuid4() + account.email = fake.email() + account.name = fake.name() + account.avatar_url = fake.url() + account.tenant_id = tenant.id + account.status = "active" + account.type = "normal" + account.role = "owner" + account.interface_language = "en-US" + account.created_at = fake.date_time_this_year() + account.updated_at = account.created_at + + db_session_with_containers.add(account) + db_session_with_containers.commit() + return account + + def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None): + """ + Helper method to create a test dataset with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant: Tenant instance for the dataset + account: Account instance for the dataset + fake: Faker instance for generating test data + + Returns: + Dataset: Created test dataset instance + """ + fake = fake or Faker() + dataset = Dataset() + dataset.id = fake.uuid4() + dataset.tenant_id = tenant.id + dataset.name = f"Test Dataset {fake.word()}" + dataset.description = fake.text(max_nb_chars=200) + dataset.provider = "vendor" + dataset.permission = "only_me" + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + dataset.index_struct = '{"type": "paragraph"}' + dataset.created_by = account.id + dataset.created_at = fake.date_time_this_year() + dataset.updated_by = account.id + dataset.updated_at = dataset.created_at + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.built_in_field_enabled = False + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs): + """ + Helper method to create a test document with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: Dataset instance for the document + account: Account instance for the document + fake: Faker instance for generating test data + **kwargs: Additional document attributes to override defaults + + Returns: + Document: Created test document instance + """ + fake = fake or Faker() + document = Document() + document.id = fake.uuid4() + document.tenant_id = dataset.tenant_id + document.dataset_id = dataset.id + document.position = kwargs.get("position", 1) + document.data_source_type = kwargs.get("data_source_type", "upload_file") + document.data_source_info = kwargs.get("data_source_info", "{}") + document.batch = kwargs.get("batch", fake.uuid4()) + document.name = kwargs.get("name", f"Test Document {fake.word()}") + document.created_from = kwargs.get("created_from", "api") + document.created_by = account.id + document.created_at = fake.date_time_this_year() + document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year()) + document.file_id = kwargs.get("file_id", fake.uuid4()) + document.word_count = kwargs.get("word_count", fake.random_int(min=100, max=1000)) + document.parsing_completed_at = kwargs.get("parsing_completed_at", fake.date_time_this_year()) + document.cleaning_completed_at = kwargs.get("cleaning_completed_at", fake.date_time_this_year()) + document.splitting_completed_at = kwargs.get("splitting_completed_at", fake.date_time_this_year()) + document.tokens = kwargs.get("tokens", fake.random_int(min=50, max=500)) + document.indexing_latency = kwargs.get("indexing_latency", fake.random_number(digits=3)) + document.completed_at = kwargs.get("completed_at", fake.date_time_this_year()) + document.is_paused = kwargs.get("is_paused", False) + document.indexing_status = kwargs.get("indexing_status", "completed") + document.enabled = kwargs.get("enabled", True) + document.archived = kwargs.get("archived", False) + document.updated_at = fake.date_time_this_year() + document.doc_type = kwargs.get("doc_type", "text") + document.doc_metadata = kwargs.get("doc_metadata", {}) + document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_language = kwargs.get("doc_language", "en") + + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None): + """ + Helper method to create test document segments with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + document: Document instance for the segments + account: Account instance for the segments + count: Number of segments to create + fake: Faker instance for generating test data + + Returns: + list[DocumentSegment]: List of created test document segment instances + """ + fake = fake or Faker() + segments = [] + + for i in range(count): + segment = DocumentSegment() + segment.id = fake.uuid4() + segment.tenant_id = document.tenant_id + segment.dataset_id = document.dataset_id + segment.document_id = document.id + segment.position = i + 1 + segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}" + segment.answer = f"Test segment answer {i + 1}: {fake.text(max_nb_chars=100)}" + segment.word_count = fake.random_int(min=10, max=100) + segment.tokens = fake.random_int(min=5, max=50) + segment.keywords = [fake.word() for _ in range(3)] + segment.index_node_id = f"node_{fake.uuid4()}" + segment.index_node_hash = fake.sha256() + segment.hit_count = 0 + segment.enabled = True + segment.status = "completed" + segment.created_by = account.id + segment.created_at = fake.date_time_this_year() + segment.updated_by = account.id + segment.updated_at = segment.created_at + + db_session_with_containers.add(segment) + segments.append(segment) + + db_session_with_containers.commit() + return segments + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers): + """ + Test successful segment deletion from index with comprehensive verification. + + This test verifies: + - Proper task execution with valid dataset and document + - Index processor factory initialization with correct document form + - Index processor clean method called with correct parameters + - Database session properly closed after execution + - Task completes without exceptions + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + # Extract index node IDs for the task + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None # Task should return None on success + + # Verify index processor factory was called with correct document form + mock_index_processor_factory.assert_called_once_with(document.doc_form) + + # Verify index processor clean method was called with correct parameters + # Note: We can't directly compare Dataset objects as they are different instances + # from database queries, so we verify the call was made and check the parameters + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers): + """ + Test task behavior when dataset is not found. + + This test verifies: + - Task handles missing dataset gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + non_existent_dataset_id = fake.uuid4() + non_existent_document_id = fake.uuid4() + index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] + + # Execute the task with non-existent dataset + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when dataset not found + + def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers): + """ + Test task behavior when document is not found. + + This test verifies: + - Task handles missing document gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + + non_existent_document_id = fake.uuid4() + index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] + + # Execute the task with non-existent document + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when document not found + + def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers): + """ + Test task behavior when document is disabled. + + This test verifies: + - Task handles disabled document gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data with disabled document + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake, enabled=False) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Execute the task with disabled document + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when document is disabled + + def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers): + """ + Test task behavior when document is archived. + + This test verifies: + - Task handles archived document gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data with archived document + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake, archived=True) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Execute the task with archived document + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when document is archived + + def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers): + """ + Test task behavior when document indexing is not completed. + + This test verifies: + - Task handles incomplete indexing status gracefully + - No index processor operations are attempted + - Task returns early without exceptions + - Database session is properly closed + """ + fake = Faker() + + # Create test data with incomplete indexing + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document( + db_session_with_containers, dataset, account, fake, indexing_status="indexing" + ) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Execute the task with incomplete indexing + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without exceptions + assert result is None # Task should return None when indexing is not completed + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_index_processor_clean( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test index processor clean method integration with different document forms. + + This test verifies: + - Index processor factory creates correct processor for different document forms + - Clean method is called with proper parameters for each document form + - Task handles different index types correctly + - Database session is properly managed + """ + fake = Faker() + + # Test different document forms + document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + + for doc_form in document_forms: + # Create test data for each document form + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake, doc_form=doc_form) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None + + # Verify index processor factory was called with correct document form + mock_index_processor_factory.assert_called_with(doc_form) + + # Verify index processor clean method was called with correct parameters + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + # Reset mocks for next iteration + mock_index_processor_factory.reset_mock() + mock_processor.reset_mock() + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_exception_handling( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test exception handling in the task. + + This test verifies: + - Task handles index processor exceptions gracefully + - Database session is properly closed even when exceptions occur + - Task logs exceptions appropriately + - No unhandled exceptions are raised + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) + + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor to raise an exception + mock_processor = MagicMock() + mock_processor.clean.side_effect = Exception("Index processor error") + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task - should not raise exception + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed without raising exceptions + assert result is None # Task should return None even when exceptions occur + + # Verify index processor clean method was called + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_empty_index_node_ids( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test task behavior with empty index node IDs list. + + This test verifies: + - Task handles empty index node IDs gracefully + - Index processor clean method is called with empty list + - Task completes successfully + - Database session is properly managed + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + + # Use empty index node IDs + index_node_ids = [] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None + + # Verify index processor clean method was called with empty list + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match (empty list) + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + def test_delete_segment_from_index_task_large_index_node_ids( + self, mock_index_processor_factory, db_session_with_containers + ): + """ + Test task behavior with large number of index node IDs. + + This test verifies: + - Task handles large lists of index node IDs efficiently + - Index processor clean method is called with all node IDs + - Task completes successfully with large datasets + - Database session is properly managed + """ + fake = Faker() + + # Create test data + tenant = self._create_test_tenant(db_session_with_containers, fake) + account = self._create_test_account(db_session_with_containers, tenant, fake) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + + # Create large number of segments + segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) + index_node_ids = [segment.index_node_id for segment in segments] + + # Mock the index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + # Execute the task + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + + # Verify the task completed successfully + assert result is None + + # Verify index processor clean method was called with all node IDs + assert mock_processor.clean.call_count == 1 + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Verify dataset ID matches + assert call_args[0][1] == index_node_ids # Verify index node IDs match + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True + + # Verify all node IDs were passed + assert len(call_args[0][1]) == 50 diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py new file mode 100644 index 0000000000..e1d63e993b --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -0,0 +1,615 @@ +""" +Integration tests for disable_segment_from_index_task using TestContainers. + +This module provides comprehensive integration tests for the disable_segment_from_index_task +using real database and Redis containers to ensure the task works correctly with actual +data and external dependencies. +""" + +import logging +import time +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest +from faker import Faker + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.disable_segment_from_index_task import disable_segment_from_index_task + +logger = logging.getLogger(__name__) + + +class TestDisableSegmentFromIndexTask: + """Integration tests for disable_segment_from_index_task using testcontainers.""" + + @pytest.fixture + def mock_index_processor(self): + """Mock IndexProcessorFactory and its clean method.""" + with patch("tasks.disable_segment_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = mock_factory.return_value.init_index_processor.return_value + mock_processor.clean.return_value = None + yield mock_processor + + def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + plan="basic", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join with owner role + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: + """ + Helper method to create a test dataset. + + Args: + tenant: Tenant instance + account: Account instance + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + tenant_id=tenant.id, + name=fake.sentence(nb_words=3), + description=fake.text(max_nb_chars=200), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document( + self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" + ) -> Document: + """ + Helper method to create a test document. + + Args: + dataset: Dataset instance + tenant: Tenant instance + account: Account instance + doc_form: Document form type + + Returns: + Document: Created document instance + """ + fake = Faker() + + document = Document( + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + batch=fake.uuid4(), + name=fake.file_name(), + created_from="api", + created_by=account.id, + indexing_status="completed", + enabled=True, + archived=False, + doc_form=doc_form, + word_count=1000, + tokens=500, + completed_at=datetime.now(UTC), + ) + db.session.add(document) + db.session.commit() + + return document + + def _create_test_segment( + self, + document: Document, + dataset: Dataset, + tenant: Tenant, + account: Account, + status: str = "completed", + enabled: bool = True, + ) -> DocumentSegment: + """ + Helper method to create a test document segment. + + Args: + document: Document instance + dataset: Dataset instance + tenant: Tenant instance + account: Account instance + status: Segment status + enabled: Whether segment is enabled + + Returns: + DocumentSegment: Created segment instance + """ + fake = Faker() + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=fake.text(max_nb_chars=500), + word_count=100, + tokens=50, + index_node_id=fake.uuid4(), + index_node_hash=fake.sha256(), + status=status, + enabled=enabled, + created_by=account.id, + completed_at=datetime.now(UTC) if status == "completed" else None, + ) + db.session.add(segment) + db.session.commit() + + return segment + + def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): + """ + Test successful segment disabling from index. + + This test verifies: + - Segment is found and validated + - Index processor clean method is called with correct parameters + - Redis cache is cleared + - Task completes successfully + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Set up Redis cache + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task completed successfully + assert result is None # Task returns None on success + + # Verify index processor was called correctly + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Check dataset ID + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + # Verify Redis cache was cleared + assert redis_client.get(indexing_cache_key) is None + + # Verify segment is still in database + db.session.refresh(segment) + assert segment.id is not None + + def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment is not found. + + This test verifies: + - Task handles non-existent segment gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Use a non-existent segment ID + fake = Faker() + non_existent_segment_id = fake.uuid4() + + # Act: Execute the task with non-existent segment + result = disable_segment_from_index_task(non_existent_segment_id) + + # Assert: Verify the task handled the error gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment is not in completed status. + + This test verifies: + - Task rejects segments that are not completed + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with non-completed segment + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the invalid status gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment has no associated dataset. + + This test verifies: + - Task handles segments without dataset gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Manually remove dataset association + segment.dataset_id = "00000000-0000-0000-0000-000000000000" + db.session.commit() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the missing dataset gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + """ + Test handling when segment has no associated document. + + This test verifies: + - Task handles segments without document gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Manually remove document association + segment.document_id = "00000000-0000-0000-0000-000000000000" + db.session.commit() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the missing document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document is disabled. + + This test verifies: + - Task handles disabled documents gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with disabled document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.enabled = False + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the disabled document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document is archived. + + This test verifies: + - Task handles archived documents gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with archived document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.archived = True + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the archived document gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + """ + Test handling when document indexing is not completed. + + This test verifies: + - Task handles documents with incomplete indexing gracefully + - No index processor operations are performed + - Task returns early without errors + """ + # Arrange: Create test data with incomplete indexing + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + document.indexing_status = "indexing" + db.session.commit() + + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the incomplete indexing gracefully + assert result is None + + # Verify index processor was not called + mock_index_processor.clean.assert_not_called() + + def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + """ + Test handling when index processor raises an exception. + + This test verifies: + - Task handles index processor exceptions gracefully + - Segment is re-enabled on failure + - Redis cache is still cleared + - Database changes are committed + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Set up Redis cache + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + + # Configure mock to raise exception + mock_index_processor.clean.side_effect = Exception("Index processor error") + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task handled the exception gracefully + assert result is None + + # Verify index processor was called + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + # Check that the call was made with the correct parameters + assert len(call_args[0]) == 2 # Check two arguments were passed + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + # Verify segment was re-enabled + db.session.refresh(segment) + assert segment.enabled is True + + # Verify Redis cache was still cleared + assert redis_client.get(indexing_cache_key) is None + + def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + """ + Test disabling segments with different document forms. + + This test verifies: + - Task works with different document form types + - Correct index processor is initialized for each form + - Index processor clean method is called correctly + """ + # Test different document forms + doc_forms = ["text_model", "qa_model", "table_model"] + + for doc_form in doc_forms: + # Arrange: Create test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Reset mock for each iteration + mock_index_processor.reset_mock() + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify the task completed successfully + assert result is None + + # Verify correct index processor was initialized + mock_index_processor.clean.assert_called_once() + call_args = mock_index_processor.clean.call_args + assert call_args[0][0].id == dataset.id # Check dataset ID + assert call_args[0][1] == [segment.index_node_id] # Check index node IDs + + def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + """ + Test Redis cache handling during segment disabling. + + This test verifies: + - Redis cache is properly set before task execution + - Cache is cleared after task completion + - Cache handling works with different scenarios + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Test with cache present + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + assert redis_client.get(indexing_cache_key) is not None + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify cache was cleared + assert result is None + assert redis_client.get(indexing_cache_key) is None + + # Test with no cache present + segment2 = self._create_test_segment(document, dataset, tenant, account) + result2 = disable_segment_from_index_task(segment2.id) + + # Assert: Verify task still works without cache + assert result2 is None + + def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + """ + Test performance timing of segment disabling task. + + This test verifies: + - Task execution time is reasonable + - Performance logging works correctly + - Task completes within expected time bounds + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task and measure time + start_time = time.perf_counter() + result = disable_segment_from_index_task(segment.id) + end_time = time.perf_counter() + + # Assert: Verify task completed successfully and timing is reasonable + assert result is None + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + """ + Test database session management during task execution. + + This test verifies: + - Database sessions are properly managed + - Sessions are closed after task completion + - No session leaks occur + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + segment = self._create_test_segment(document, dataset, tenant, account) + + # Act: Execute the task + result = disable_segment_from_index_task(segment.id) + + # Assert: Verify task completed and session management worked + assert result is None + + # Verify segment is still accessible (session was properly managed) + db.session.refresh(segment) + assert segment.id is not None + + def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + """ + Test concurrent execution of segment disabling tasks. + + This test verifies: + - Multiple tasks can run concurrently + - Each task processes its own segment correctly + - No interference between concurrent tasks + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset = self._create_test_dataset(tenant, account) + document = self._create_test_document(dataset, tenant, account) + + segments = [] + for i in range(3): + segment = self._create_test_segment(document, dataset, tenant, account) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + results = [] + for segment in segments: + result = disable_segment_from_index_task(segment.id) + results.append(result) + + # Assert: Verify all tasks completed successfully + assert all(result is None for result in results) + + # Verify all segments were processed + assert mock_index_processor.clean.call_count == len(segments) + + # Verify each segment was processed with correct parameters + for segment in segments: + # Check that clean was called with this segment's dataset and index_node_id + found = False + for call in mock_index_processor.clean.call_args_list: + if call[0][0].id == dataset.id and call[0][1] == [segment.index_node_id]: + found = True + break + assert found, f"Segment {segment.id} was not processed correctly" diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py new file mode 100644 index 0000000000..5fdb8c617c --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -0,0 +1,729 @@ +""" +TestContainers-based integration tests for disable_segments_from_index_task. + +This module provides comprehensive integration testing for the disable_segments_from_index_task +using TestContainers to ensure realistic database interactions and proper isolation. +The task is responsible for removing document segments from the search index when they are disabled. +""" + +from unittest.mock import MagicMock, patch + +from faker import Faker + +from models import Account, Dataset, DocumentSegment +from models import Document as DatasetDocument +from models.dataset import DatasetProcessRule +from tasks.disable_segments_from_index_task import disable_segments_from_index_task + + +class TestDisableSegmentsFromIndexTask: + """ + Comprehensive integration tests for disable_segments_from_index_task using testcontainers. + + This test class covers all major functionality of the disable_segments_from_index_task: + - Successful segment disabling with proper index cleanup + - Error handling for various edge cases + - Database state validation after task execution + - Redis cache cleanup verification + - Index processor integration testing + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + def _create_test_account(self, db_session_with_containers, fake=None): + """ + Helper method to create a test account with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Account: Created test account instance + """ + fake = fake or Faker() + account = Account() + account.id = fake.uuid4() + account.email = fake.email() + account.name = fake.name() + account.avatar_url = fake.url() + account.tenant_id = fake.uuid4() + account.status = "active" + account.type = "normal" + account.role = "owner" + account.interface_language = "en-US" + account.created_at = fake.date_time_this_year() + account.updated_at = account.created_at + + # Create a tenant for the account + from models.account import Tenant + + tenant = Tenant() + tenant.id = account.tenant_id + tenant.name = f"Test Tenant {fake.company()}" + tenant.plan = "basic" + tenant.status = "active" + tenant.created_at = fake.date_time_this_year() + tenant.updated_at = tenant.created_at + + from extensions.ext_database import db + + db.session.add(tenant) + db.session.add(account) + db.session.commit() + + # Set the current tenant for the account + account.current_tenant = tenant + + return account + + def _create_test_dataset(self, db_session_with_containers, account, fake=None): + """ + Helper method to create a test dataset with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + account: The account creating the dataset + fake: Faker instance for generating test data + + Returns: + Dataset: Created test dataset instance + """ + fake = fake or Faker() + dataset = Dataset() + dataset.id = fake.uuid4() + dataset.tenant_id = account.tenant_id + dataset.name = f"Test Dataset {fake.word()}" + dataset.description = fake.text(max_nb_chars=200) + dataset.provider = "vendor" + dataset.permission = "only_me" + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + dataset.created_by = account.id + dataset.updated_by = account.id + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.built_in_field_enabled = False + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + """ + Helper method to create a test document with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: The dataset containing the document + account: The account creating the document + fake: Faker instance for generating test data + + Returns: + DatasetDocument: Created test document instance + """ + fake = fake or Faker() + document = DatasetDocument() + document.id = fake.uuid4() + document.tenant_id = dataset.tenant_id + document.dataset_id = dataset.id + document.position = 1 + document.data_source_type = "upload_file" + document.data_source_info = '{"upload_file_id": "test_file_id"}' + document.batch = fake.uuid4() + document.name = f"Test Document {fake.word()}.txt" + document.created_from = "upload_file" + document.created_by = account.id + document.created_api_request_id = fake.uuid4() + document.processing_started_at = fake.date_time_this_year() + document.file_id = fake.uuid4() + document.word_count = fake.random_int(min=100, max=1000) + document.parsing_completed_at = fake.date_time_this_year() + document.cleaning_completed_at = fake.date_time_this_year() + document.splitting_completed_at = fake.date_time_this_year() + document.tokens = fake.random_int(min=50, max=500) + document.indexing_started_at = fake.date_time_this_year() + document.indexing_completed_at = fake.date_time_this_year() + document.indexing_status = "completed" + document.enabled = True + document.archived = False + document.doc_form = "text_model" # Use text_model form for testing + document.doc_language = "en" + + from extensions.ext_database import db + + db.session.add(document) + db.session.commit() + + return document + + def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + """ + Helper method to create test document segments with realistic data. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + document: The document containing the segments + dataset: The dataset containing the document + account: The account creating the segments + count: Number of segments to create + fake: Faker instance for generating test data + + Returns: + List[DocumentSegment]: Created test segment instances + """ + fake = fake or Faker() + segments = [] + + for i in range(count): + segment = DocumentSegment() + segment.id = fake.uuid4() + segment.tenant_id = dataset.tenant_id + segment.dataset_id = dataset.id + segment.document_id = document.id + segment.position = i + 1 + segment.content = f"Test segment content {i + 1}: {fake.text(max_nb_chars=200)}" + segment.answer = f"Test answer {i + 1}" if i % 2 == 0 else None + segment.word_count = fake.random_int(min=10, max=100) + segment.tokens = fake.random_int(min=5, max=50) + segment.keywords = [fake.word() for _ in range(3)] + segment.index_node_id = f"node_{segment.id}" + segment.index_node_hash = fake.sha256() + segment.hit_count = 0 + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + segment.status = "completed" + segment.created_by = account.id + segment.updated_by = account.id + segment.indexing_at = fake.date_time_this_year() + segment.completed_at = fake.date_time_this_year() + segment.error = None + segment.stopped_at = None + + segments.append(segment) + + from extensions.ext_database import db + + for segment in segments: + db.session.add(segment) + db.session.commit() + + return segments + + def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + """ + Helper method to create a dataset process rule. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset: The dataset for the process rule + fake: Faker instance for generating test data + + Returns: + DatasetProcessRule: Created process rule instance + """ + fake = fake or Faker() + process_rule = DatasetProcessRule() + process_rule.id = fake.uuid4() + process_rule.tenant_id = dataset.tenant_id + process_rule.dataset_id = dataset.id + process_rule.mode = "automatic" + process_rule.rules = ( + "{" + '"mode": "automatic", ' + '"rules": {' + '"pre_processing_rules": [], "segmentation": ' + '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' + "}" + ) + process_rule.created_by = dataset.created_by + process_rule.updated_by = dataset.updated_by + + from extensions.ext_database import db + + db.session.add(process_rule) + db.session.commit() + + return process_rule + + def test_disable_segments_success(self, db_session_with_containers): + """ + Test successful disabling of segments from index. + + This test verifies that the task can correctly disable segments from the index + when all conditions are met, including proper index cleanup and database state updates. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor to avoid external dependencies + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify index processor was called correctly + mock_factory.assert_called_once_with(document.doc_form) + mock_processor.clean.assert_called_once() + + # Verify the call arguments (checking by attributes rather than object identity) + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # First argument should be the dataset + assert sorted(call_args[0][1]) == sorted( + [segment.index_node_id for segment in segments] + ) # Compare sorted lists to handle any order while preserving duplicates + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is False + + # Verify Redis cache cleanup was called for each segment + assert mock_redis.delete.call_count == len(segments) + for segment in segments: + expected_key = f"segment_{segment.id}_indexing" + mock_redis.delete.assert_any_call(expected_key) + + def test_disable_segments_dataset_not_found(self, db_session_with_containers): + """ + Test handling when dataset is not found. + + This test ensures that the task correctly handles cases where the specified + dataset doesn't exist, logging appropriate messages and returning early. + """ + # Arrange + fake = Faker() + non_existent_dataset_id = fake.uuid4() + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when dataset is not found + mock_redis.delete.assert_not_called() + + def test_disable_segments_document_not_found(self, db_session_with_containers): + """ + Test handling when document is not found. + + This test ensures that the task correctly handles cases where the specified + document doesn't exist, logging appropriate messages and returning early. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, non_existent_document_id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when document is not found + mock_redis.delete.assert_not_called() + + def test_disable_segments_document_invalid_status(self, db_session_with_containers): + """ + Test handling when document has invalid status for disabling. + + This test ensures that the task correctly handles cases where the document + is not enabled, archived, or not completed, preventing invalid operations. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + + # Test case 1: Document not enabled + document.enabled = False + from extensions.ext_database import db + + db.session.commit() + + segment_ids = [segment.id for segment in segments] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when document status is invalid + mock_redis.delete.assert_not_called() + + # Test case 2: Document archived + document.enabled = True + document.archived = True + db.session.commit() + + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_redis.delete.assert_not_called() + + # Test case 3: Document indexing not completed + document.enabled = True + document.archived = False + document.indexing_status = "indexing" + db.session.commit() + + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_redis.delete.assert_not_called() + + def test_disable_segments_no_segments_found(self, db_session_with_containers): + """ + Test handling when no segments are found for the given IDs. + + This test ensures that the task correctly handles cases where the specified + segment IDs don't exist or don't match the dataset/document criteria. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + # Use non-existent segment IDs + non_existent_segment_ids = [fake.uuid4() for _ in range(3)] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(non_existent_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when no segments are found + mock_redis.delete.assert_not_called() + + def test_disable_segments_index_processor_error(self, db_session_with_containers): + """ + Test handling when index processor encounters an error. + + This test verifies that the task correctly handles index processor errors + by rolling back segment states and ensuring proper cleanup. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor to raise an exception + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean.side_effect = Exception("Index processor error") + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify segments were rolled back to enabled state + from extensions.ext_database import db + + db.session.refresh(segments[0]) + db.session.refresh(segments[1]) + + # Check that segments are re-enabled after error + updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + + for segment in updated_segments: + assert segment.enabled is True + assert segment.disabled_at is None + assert segment.disabled_by is None + + # Verify Redis cache cleanup was still called + assert mock_redis.delete.call_count == len(segments) + + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + """ + Test disabling segments with different document forms. + + This test verifies that the task correctly handles different document forms + (paragraph, qa, parent_child) and initializes the appropriate index processor. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Test different document forms + doc_forms = ["text_model", "qa_model", "hierarchical_model"] + + for doc_form in doc_forms: + # Update document form + document.doc_form = doc_form + from extensions.ext_database import db + + db.session.commit() + + # Mock the index processor factory + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + mock_factory.assert_called_with(doc_form) + + def test_disable_segments_performance_timing(self, db_session_with_containers): + """ + Test that the task properly measures and logs performance timing. + + This test verifies that the task correctly measures execution time + and logs performance metrics for monitoring purposes. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 3, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Mock time.perf_counter to control timing + with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter: + mock_perf_counter.side_effect = [1000.0, 1000.5] # 0.5 seconds execution time + + # Mock logger to capture log messages + with patch("tasks.disable_segments_from_index_task.logger") as mock_logger: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify performance logging + mock_logger.info.assert_called() + log_calls = [call[0][0] for call in mock_logger.info.call_args_list] + performance_log = next((call for call in log_calls if "latency" in call), None) + assert performance_log is not None + assert "0.5" in performance_log # Should log the execution time + + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + """ + Test that Redis cache is properly cleaned up for all segments. + + This test verifies that the task correctly removes indexing cache entries + from Redis for all processed segments, preventing stale cache issues. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 5, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client to track delete calls + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify Redis delete was called for each segment + assert mock_redis.delete.call_count == len(segments) + + # Verify correct cache keys were used + expected_keys = [f"segment_{segment.id}_indexing" for segment in segments] + actual_calls = [call[0][0] for call in mock_redis.delete.call_args_list] + + for expected_key in expected_keys: + assert expected_key in actual_calls + + def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + """ + Test that database session is properly closed after task execution. + + This test verifies that the task correctly manages database sessions + and ensures proper cleanup to prevent connection leaks. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + segment_ids = [segment.id for segment in segments] + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Mock db.session.close to verify it's called + with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close: + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Verify session was closed + mock_close.assert_called() + + def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + """ + Test handling when empty segment IDs list is provided. + + This test ensures that the task correctly handles edge cases where + an empty list of segment IDs is provided. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + empty_segment_ids = [] + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + # Act + result = disable_segments_from_index_task(empty_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + # Redis should not be called when no segments are provided + mock_redis.delete.assert_not_called() + + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + """ + Test handling when some segment IDs are valid and others are invalid. + + This test verifies that the task correctly processes only the valid + segment IDs and ignores invalid ones. + """ + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + dataset = self._create_test_dataset(db_session_with_containers, account, fake) + document = self._create_test_document(db_session_with_containers, dataset, account, fake) + segments = self._create_test_segments(db_session_with_containers, document, dataset, account, 2, fake) + self._create_dataset_process_rule(db_session_with_containers, dataset, fake) + + # Mix valid and invalid segment IDs + valid_segment_ids = [segment.id for segment in segments] + invalid_segment_ids = [fake.uuid4() for _ in range(2)] + mixed_segment_ids = valid_segment_ids + invalid_segment_ids + + # Mock the index processor + with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + # Mock Redis client + with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: + mock_redis.delete.return_value = True + + # Act + result = disable_segments_from_index_task(mixed_segment_ids, dataset.id, document.id) + + # Assert + assert result is None # Task should complete without returning a value + + # Verify index processor was called with only valid segment node IDs + expected_node_ids = [segment.index_node_id for segment in segments] + mock_processor.clean.assert_called_once() + + # Verify the call arguments + call_args = mock_processor.clean.call_args + assert call_args[0][0].id == dataset.id # First argument should be the dataset + assert sorted(call_args[0][1]) == sorted( + expected_node_ids + ) # Compare sorted lists to handle any order while preserving duplicates + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is False + + # Verify Redis cleanup was called only for valid segments + assert mock_redis.delete.call_count == len(segments) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py new file mode 100644 index 0000000000..f75dcf06e1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -0,0 +1,554 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document +from tasks.document_indexing_task import document_indexing_task + + +class TestDocumentIndexingTask: + """Integration tests for document_indexing_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner, + patch("tasks.document_indexing_task.FeatureService") as mock_feature_service, + ): + # Setup mock indexing runner + mock_runner_instance = MagicMock() + mock_indexing_runner.return_value = mock_runner_instance + + # Setup mock feature service + mock_features = MagicMock() + mock_features.billing.enabled = False + mock_feature_service.get_features.return_value = mock_features + + yield { + "indexing_runner": mock_indexing_runner, + "indexing_runner_instance": mock_runner_instance, + "feature_service": mock_feature_service, + "features": mock_features, + } + + def _create_test_dataset_and_documents( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + ): + """ + Helper method to create a test dataset and documents for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(document_count): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def _create_test_dataset_with_billing_features( + self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ): + """ + Helper method to create a test dataset with billing features configured. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + billing_enabled: Whether billing is enabled + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(3): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Configure billing features + mock_external_service_dependencies["features"].billing.enabled = billing_enabled + if billing_enabled: + mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox" + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 50 + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful document indexing with multiple documents. + + This test verifies: + - Proper dataset retrieval from database + - Correct document processing and status updates + - IndexingRunner integration + - Database state updates + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify the expected outcomes + # Verify indexing runner was called correctly + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 3 + + def test_document_indexing_task_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. + + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary indexing runner calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + document_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent dataset + document_indexing_task(non_existent_dataset_id, document_ids) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["indexing_runner"].assert_not_called() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_document_indexing_task_document_not_found_in_dataset( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when some documents don't exist in the dataset. + + This test verifies: + - Only existing documents are processed + - Non-existent documents are ignored + - Indexing runner receives only valid documents + - Database state updates correctly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Mix existing and non-existent document IDs + fake = Faker() + existing_document_ids = [doc.id for doc in documents] + non_existent_document_ids = [fake.uuid4() for _ in range(2)] + all_document_ids = existing_document_ids + non_existent_document_ids + + # Act: Execute the task with mixed document IDs + document_indexing_task(dataset.id, all_document_ids) + + # Assert: Verify only existing documents were processed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only existing documents were updated + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + # Verify the run method was called with only existing documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 2 # Only existing documents + + def test_document_indexing_task_indexing_runner_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of IndexingRunner exceptions. + + This test verifies: + - Exceptions from IndexingRunner are properly caught + - Task completes without raising exceptions + - Database session is properly closed + - Error logging occurs + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception( + "Indexing runner failed" + ) + + # Act: Execute the task + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + def test_document_indexing_task_mixed_document_states( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test processing documents with mixed initial states. + + This test verifies: + - Documents with different initial states are handled correctly + - Only valid documents are processed + - Database state updates are consistent + - IndexingRunner receives correct documents + """ + # Arrange: Create test data + dataset, base_documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Create additional documents with different states + fake = Faker() + extra_documents = [] + + # Document with different indexing status + doc1 = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=2, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="completed", # Already completed + enabled=True, + ) + db.session.add(doc1) + extra_documents.append(doc1) + + # Document with disabled status + doc2 = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=False, # Disabled + ) + db.session.add(doc2) + extra_documents.append(doc2) + + db.session.commit() + + all_documents = base_documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with mixed document states + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify processing + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify all documents were updated to parsing status + for document in all_documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + # Verify the run method was called with all documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 4 + + def test_document_indexing_task_billing_sandbox_plan_batch_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for sandbox plan batch upload limit. + + This test verifies: + - Sandbox plan batch upload limit enforcement + - Error handling for batch upload limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure sandbox plan with batch limit + mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox" + + # Create more documents than sandbox plan allows (limit is 1) + fake = Faker() + extra_documents = [] + for i in range(2): # Total will be 5 documents (3 existing + 2 new) + document = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=i + 3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=True, + ) + db.session.add(document) + extra_documents.append(document) + + db.session.commit() + all_documents = documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with too many documents for sandbox plan + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + for document in all_documents: + db.session.refresh(document) + assert document.indexing_status == "error" + assert document.error is not None + assert "batch upload" in document.error + assert document.stopped_at is not None + + # Verify no indexing runner was called + mock_external_service_dependencies["indexing_runner"].assert_not_called() + + def test_document_indexing_task_billing_disabled_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful processing when billing is disabled. + + This test verifies: + - Processing continues normally when billing is disabled + - No billing validation occurs + - Documents are processed successfully + - IndexingRunner is called correctly + """ + # Arrange: Create test data with billing disabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=False + ) + + document_ids = [doc.id for doc in documents] + + # Act: Execute the task with billing disabled + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify successful processing + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None + + def test_document_indexing_task_document_is_paused_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of DocumentIsPausedError from IndexingRunner. + + This test verifies: + - DocumentIsPausedError is properly caught and handled + - Task completes without raising exceptions + - Appropriate logging occurs + - Database session is properly closed + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise DocumentIsPausedError + from core.indexing_runner import DocumentIsPausedError + + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError( + "Document indexing is paused" + ) + + # Act: Execute the task + document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + for document in documents: + db.session.refresh(document) + assert document.indexing_status == "parsing" + assert document.processing_started_at is not None diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index aefb4bf8b0..b6697ac5d4 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -9,7 +9,6 @@ from flask_restx import Api import services.errors.account from controllers.console.auth.error import AuthenticationFailedError from controllers.console.auth.login import LoginApi -from controllers.console.error import AccountNotFound class TestAuthenticationSecurity: @@ -27,31 +26,33 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.authenticate") - @patch("controllers.console.auth.login.AccountService.send_reset_password_email") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") def test_login_invalid_email_with_registration_allowed( - self, mock_get_invitation, mock_send_email, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): - """Test that invalid email sends reset password email when registration is allowed.""" + """Test that invalid email raises AuthenticationFailedError when account not found.""" # Arrange mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None - mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = True - mock_send_email.return_value = "token123" # Act with self.app.test_request_context( "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} ): login_api = LoginApi() - result = login_api.post() - # Assert - assert result == {"result": "fail", "data": "token123", "code": "account_not_found"} - mock_send_email.assert_called_once_with(email="nonexistent@example.com", language="en-US") + # Assert + with pytest.raises(AuthenticationFailedError) as exc_info: + login_api.post() + + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." + mock_add_rate_limit.assert_called_once_with("nonexistent@example.com") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @@ -87,16 +88,17 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.FeatureService.get_system_features") @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") def test_login_invalid_email_with_registration_disabled( - self, mock_get_invitation, mock_authenticate, mock_is_rate_limit, mock_features, mock_db + self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): - """Test that invalid email raises AccountNotFound when registration is disabled.""" + """Test that invalid email raises AuthenticationFailedError when account not found.""" # Arrange mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None - mock_authenticate.side_effect = services.errors.account.AccountNotFoundError("Account not found") + mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = False @@ -107,10 +109,12 @@ class TestAuthenticationSecurity: login_api = LoginApi() # Assert - with pytest.raises(AccountNotFound) as exc_info: + with pytest.raises(AuthenticationFailedError) as exc_info: login_api.post() - assert exc_info.value.error_code == "account_not_found" + assert exc_info.value.error_code == "authentication_failed" + assert exc_info.value.description == "Invalid email or password." + mock_add_rate_limit.assert_called_once_with("nonexistent@example.com") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.FeatureService.get_system_features") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 037c9f2745..a7bdf5de33 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -12,7 +12,7 @@ from controllers.console.auth.oauth import ( ) from libs.oauth import OAuthUserInfo from models.account import AccountStatus -from services.errors.account import AccountNotFoundError +from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @@ -451,7 +451,7 @@ class TestAccountGeneration: with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): if not allow_register and not existing_account: - with pytest.raises(AccountNotFoundError): + with pytest.raises(AccountRegisterError): _generate_account("github", user_info) else: result = _generate_account("github", user_info) diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f1d741602a..895ebdd751 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -29,7 +29,7 @@ class TestHandleMCPRequest: """Setup test fixtures""" self.app = Mock(spec=App) self.app.name = "test_app" - self.app.mode = AppMode.CHAT.value + self.app.mode = AppMode.CHAT self.mcp_server = Mock(spec=AppMCPServer) self.mcp_server.description = "Test server" @@ -196,7 +196,7 @@ class TestIndividualHandlers: def test_handle_list_tools(self): """Test list tools handler""" app_name = "test_app" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT description = "Test server" parameters_dict: dict[str, str] = {} user_input_form: list[VariableEntity] = [] @@ -212,7 +212,7 @@ class TestIndividualHandlers: def test_handle_call_tool(self, mock_app_generate): """Test call tool handler""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create mock request mock_request = Mock() @@ -252,7 +252,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_chat_mode(self): """Test building parameter schema for chat mode""" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT parameters_dict: dict[str, str] = {"name": "Enter your name"} user_input_form = [ @@ -275,7 +275,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_workflow_mode(self): """Test building parameter schema for workflow mode""" - app_mode = AppMode.WORKFLOW.value + app_mode = AppMode.WORKFLOW parameters_dict: dict[str, str] = {"input_text": "Enter text"} user_input_form = [ @@ -298,7 +298,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_chat_mode(self): """Test preparing tool arguments for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT arguments = {"query": "test question", "name": "John"} @@ -312,7 +312,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_workflow_mode(self): """Test preparing tool arguments for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW arguments = {"input_text": "test input"} @@ -324,7 +324,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_completion_mode(self): """Test preparing tool arguments for completion mode""" app = Mock(spec=App) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION arguments = {"name": "John"} @@ -336,7 +336,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_chat(self): """Test extracting answer from mapping response for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT response = {"answer": "test answer", "other": "data"} @@ -347,7 +347,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_workflow(self): """Test extracting answer from mapping response for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW response = {"data": {"outputs": {"result": "test result"}}} diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 607728efd8..6689e13b96 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -20,7 +20,6 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): } mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) job_id = firecrawl_app.crawl_url(url, params) - print(f"job_id: {job_id}") assert job_id is not None assert isinstance(job_id, str) diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py index 57ddacd13d..0bf4a3cf91 100644 --- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -15,7 +15,7 @@ class FakeResponse: self.status_code = status_code self.headers = headers or {} self.content = content - self.text = text if text else content.decode("utf-8", errors="ignore") + self.text = text or content.decode("utf-8", errors="ignore") # --------------------------- diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 4c8d983d20..c9cfabca6e 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -129,7 +129,6 @@ class TestSegmentDumpAndLoad: """Test basic segment serialization compatibility""" model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")]) json = model.model_dump_json() - print("Json: ", json) loaded = _Segments.model_validate_json(json) assert loaded == model @@ -137,7 +136,6 @@ class TestSegmentDumpAndLoad: """Test number segment serialization compatibility""" model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)]) json = model.model_dump_json() - print("Json: ", json) loaded = _Segments.model_validate_json(json) assert loaded == model @@ -145,7 +143,6 @@ class TestSegmentDumpAndLoad: """Test variable serialization compatibility""" model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")]) json = model.model_dump_json() - print("Json: ", json) restored = _Variables.model_validate_json(json) assert restored == model diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 2d8d433c46..b8f901770c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -342,4 +342,3 @@ def test_http_request_node_form_with_multiple_files(monkeypatch: pytest.MonkeyPa assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["body"] == '{"status":"success"}' - print(result.outputs["body"]) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index ea8a88692f..2765048734 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,7 +1,6 @@ import base64 import uuid from collections.abc import Sequence -from typing import Optional from unittest import mock import pytest @@ -47,7 +46,7 @@ class MockTokenBufferMemory: self.history_messages = history_messages or [] def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: int | None = None ) -> Sequence[PromptMessage]: if message_limit is not None: return self.history_messages[-message_limit * 2 :] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 987eaf7534..49a88e57b3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -143,15 +143,11 @@ def test_remove_first_from_array(): node.init_node_data(node_config["data"]) # Skip the mock assertion since we're in a test environment - # Print the variable before running - print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") # Run the node result = list(node.run()) - # Print the variable after running and the result - print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") - print(f"Result: {result}") + # Completed run got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py new file mode 100644 index 0000000000..324f58abf6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -0,0 +1,456 @@ +import pytest + +from core.file.enums import FileType +from core.file.models import File, FileTransferMethod +from core.variables.variables import StringVariable +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from core.workflow.workflow_entry import WorkflowEntry + + +class TestWorkflowEntry: + """Test WorkflowEntry class methods.""" + + def test_mapping_user_inputs_to_variable_pool_with_system_variables(self): + """Test mapping system variables from user inputs to variable pool.""" + # Initialize variable pool with system variables + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + ), + user_inputs={}, + ) + + # Define variable mapping - sys variables mapped to other nodes + variable_mapping = { + "node1.input1": ["node1", "input1"], # Regular mapping + "node2.query": ["node2", "query"], # Regular mapping + "sys.user_id": ["output_node", "user"], # System variable mapping + } + + # User inputs including sys variables + user_inputs = { + "node1.input1": "new_user_id", + "node2.query": "test query", + "sys.user_id": "system_user", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added to pool + # Note: variable_pool.get returns Variable objects, not raw values + node1_var = variable_pool.get(["node1", "input1"]) + assert node1_var is not None + assert node1_var.value == "new_user_id" + + node2_var = variable_pool.get(["node2", "query"]) + assert node2_var is not None + assert node2_var.value == "test query" + + # System variable gets mapped to output node + output_var = variable_pool.get(["output_node", "user"]) + assert output_var is not None + assert output_var.value == "system_user" + + def test_mapping_user_inputs_to_variable_pool_with_env_variables(self): + """Test mapping environment variables from user inputs to variable pool.""" + # Initialize variable pool with environment variables + env_var = StringVariable(name="API_KEY", value="existing_key") + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + environment_variables=[env_var], + user_inputs={}, + ) + + # Add env variable to pool (simulating initialization) + variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var) + + # Define variable mapping - env variables should not be overridden + variable_mapping = { + "node1.api_key": [ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], + "node2.new_env": [ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"], + } + + # User inputs + user_inputs = { + "node1.api_key": "user_provided_key", # This should not override existing env var + "node2.new_env": "new_env_value", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify env variable was not overridden + env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"]) + assert env_value is not None + assert env_value.value == "existing_key" # Should remain unchanged + + # New env variables from user input should not be added + assert variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "NEW_ENV"]) is None + + def test_mapping_user_inputs_to_variable_pool_with_conversation_variables(self): + """Test mapping conversation variables from user inputs to variable pool.""" + # Initialize variable pool with conversation variables + conv_var = StringVariable(name="last_message", value="Hello") + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + conversation_variables=[conv_var], + user_inputs={}, + ) + + # Add conversation variable to pool + variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "last_message"], conv_var) + + # Define variable mapping + variable_mapping = { + "node1.message": ["node1", "message"], # Map to regular node + "conversation.context": ["chat_node", "context"], # Conversation var to regular node + } + + # User inputs + user_inputs = { + "node1.message": "Updated message", + "conversation.context": "New context", + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added to their target nodes + node1_var = variable_pool.get(["node1", "message"]) + assert node1_var is not None + assert node1_var.value == "Updated message" + + chat_var = variable_pool.get(["chat_node", "context"]) + assert chat_var is not None + assert chat_var.value == "New context" + + def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self): + """Test mapping regular node variables from user inputs to variable pool.""" + # Initialize empty variable pool + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping for regular nodes + variable_mapping = { + "input_node.text": ["input_node", "text"], + "llm_node.prompt": ["llm_node", "prompt"], + "code_node.input": ["code_node", "input"], + } + + # User inputs + user_inputs = { + "input_node.text": "User input text", + "llm_node.prompt": "Generate a summary", + "code_node.input": {"key": "value"}, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify regular variables were added + text_var = variable_pool.get(["input_node", "text"]) + assert text_var is not None + assert text_var.value == "User input text" + + prompt_var = variable_pool.get(["llm_node", "prompt"]) + assert prompt_var is not None + assert prompt_var.value == "Generate a summary" + + input_var = variable_pool.get(["code_node", "input"]) + assert input_var is not None + assert input_var.value == {"key": "value"} + + def test_mapping_user_inputs_with_file_handling(self): + """Test mapping file inputs from user inputs to variable pool.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping + variable_mapping = { + "file_node.file": ["file_node", "file"], + "file_node.files": ["file_node", "files"], + } + + # User inputs with file data - using remote_url which doesn't require upload_file_id + user_inputs = { + "file_node.file": { + "type": "document", + "transfer_method": "remote_url", + "url": "http://example.com/test.pdf", + }, + "file_node.files": [ + { + "type": "image", + "transfer_method": "remote_url", + "url": "http://example.com/image1.jpg", + }, + { + "type": "image", + "transfer_method": "remote_url", + "url": "http://example.com/image2.jpg", + }, + ], + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify file was converted and added + file_var = variable_pool.get(["file_node", "file"]) + assert file_var is not None + assert file_var.value.type == FileType.DOCUMENT + assert file_var.value.transfer_method == FileTransferMethod.REMOTE_URL + + # Verify file list was converted and added + files_var = variable_pool.get(["file_node", "files"]) + assert files_var is not None + assert isinstance(files_var.value, list) + assert len(files_var.value) == 2 + assert all(isinstance(f, File) for f in files_var.value) + assert files_var.value[0].type == FileType.IMAGE + assert files_var.value[1].type == FileType.IMAGE + assert files_var.value[0].type == FileType.IMAGE + assert files_var.value[1].type == FileType.IMAGE + + def test_mapping_user_inputs_missing_variable_error(self): + """Test that mapping raises error when required variable is missing.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping + variable_mapping = { + "node1.required_input": ["node1", "required_input"], + } + + # User inputs without required variable + user_inputs = { + "node1.other_input": "some value", + } + + # Should raise ValueError for missing variable + with pytest.raises(ValueError, match="Variable key node1.required_input not found in user inputs"): + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + def test_mapping_user_inputs_with_alternative_key_format(self): + """Test mapping with alternative key format (without node prefix).""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping + variable_mapping = { + "node1.input": ["node1", "input"], + } + + # User inputs with alternative key format + user_inputs = { + "input": "value without node prefix", # Alternative format without node prefix + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variable was added using alternative key + input_var = variable_pool.get(["node1", "input"]) + assert input_var is not None + assert input_var.value == "value without node prefix" + + def test_mapping_user_inputs_with_complex_selectors(self): + """Test mapping with complex node variable keys.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping - selectors can only have 2 elements + variable_mapping = { + "node1.data.field1": ["node1", "data_field1"], # Complex key mapped to simple selector + "node2.config.settings.timeout": ["node2", "timeout"], # Complex key mapped to simple selector + } + + # User inputs + user_inputs = { + "node1.data.field1": "nested value", + "node2.config.settings.timeout": 30, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify variables were added with simple selectors + data_var = variable_pool.get(["node1", "data_field1"]) + assert data_var is not None + assert data_var.value == "nested value" + + timeout_var = variable_pool.get(["node2", "timeout"]) + assert timeout_var is not None + assert timeout_var.value == 30 + + def test_mapping_user_inputs_invalid_node_variable(self): + """Test that mapping handles invalid node variable format.""" + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + ) + + # Define variable mapping with single element node variable (at least one dot is required) + variable_mapping = { + "singleelement": ["node1", "input"], # No dot separator + } + + user_inputs = {"singleelement": "some value"} # Must use exact key + + # Should NOT raise error - function accepts it and uses direct key + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify it was added + var = variable_pool.get(["node1", "input"]) + assert var is not None + assert var.value == "some value" + + def test_mapping_all_variable_types_together(self): + """Test mapping all four types of variables in one operation.""" + # Initialize variable pool with some existing variables + env_var = StringVariable(name="API_KEY", value="existing_key") + conv_var = StringVariable(name="session_id", value="session123") + + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="test_user", + app_id="test_app", + query="initial query", + ), + environment_variables=[env_var], + conversation_variables=[conv_var], + user_inputs={}, + ) + + # Add existing variables to pool + variable_pool.add([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"], env_var) + variable_pool.add([CONVERSATION_VARIABLE_NODE_ID, "session_id"], conv_var) + + # Define comprehensive variable mapping + variable_mapping = { + # System variables mapped to regular nodes + "sys.user_id": ["start", "user"], + "sys.app_id": ["start", "app"], + # Environment variables (won't be overridden) + "env.API_KEY": ["config", "api_key"], + # Conversation variables mapped to regular nodes + "conversation.session_id": ["chat", "session"], + # Regular variables + "input.text": ["input", "text"], + "process.data": ["process", "data"], + } + + # User inputs + user_inputs = { + "sys.user_id": "new_user", + "sys.app_id": "new_app", + "env.API_KEY": "attempted_override", # Should not override env var + "conversation.session_id": "new_session", + "input.text": "user input text", + "process.data": {"value": 123, "status": "active"}, + } + + # Execute mapping + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id="test_tenant", + ) + + # Verify system variables were added to their target nodes + start_user = variable_pool.get(["start", "user"]) + assert start_user is not None + assert start_user.value == "new_user" + + start_app = variable_pool.get(["start", "app"]) + assert start_app is not None + assert start_app.value == "new_app" + + # Verify env variable was not overridden (still has original value) + env_value = variable_pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY"]) + assert env_value is not None + assert env_value.value == "existing_key" + + # Environment variables get mapped to other nodes even when they exist in env pool + # But the original env value remains unchanged + config_api_key = variable_pool.get(["config", "api_key"]) + assert config_api_key is not None + assert config_api_key.value == "attempted_override" + + # Verify conversation variable was mapped to target node + chat_session = variable_pool.get(["chat", "session"]) + assert chat_session is not None + assert chat_session.value == "new_session" + + # Verify regular variables were added + input_text = variable_pool.get(["input", "text"]) + assert input_text is not None + assert input_text.value == "user input text" + + process_data = variable_pool.get(["process", "data"]) + assert process_data is not None + assert process_data.value == {"value": 123, "status": "active"} diff --git a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py index 7d295cecf2..958072223e 100644 --- a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py @@ -11,12 +11,12 @@ class TestSupabaseStorage: def test_init_success_with_all_config(self): """Test successful initialization when all required config is provided.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -31,7 +31,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_url_missing(self): """Test initialization raises ValueError when SUPABASE_URL is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = None mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" @@ -41,7 +41,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_api_key_missing(self): """Test initialization raises ValueError when SUPABASE_API_KEY is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = None mock_config.SUPABASE_BUCKET_NAME = "test-bucket" @@ -51,7 +51,7 @@ class TestSupabaseStorage: def test_init_raises_error_when_bucket_name_missing(self): """Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = None @@ -61,12 +61,12 @@ class TestSupabaseStorage: def test_create_bucket_when_not_exists(self): """Test create_bucket creates bucket when it doesn't exist.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -77,12 +77,12 @@ class TestSupabaseStorage: def test_create_bucket_when_exists(self): """Test create_bucket does not create bucket when it already exists.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -94,12 +94,12 @@ class TestSupabaseStorage: @pytest.fixture def storage_with_mock_client(self): """Fixture providing SupabaseStorage with mocked client.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -251,12 +251,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_true_when_bucket_found(self): """Test bucket_exists returns True when bucket is found in list.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -271,12 +271,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_false_when_bucket_not_found(self): """Test bucket_exists returns False when bucket is not found in list.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client @@ -294,12 +294,12 @@ class TestSupabaseStorage: def test_bucket_exists_returns_false_when_no_buckets(self): """Test bucket_exists returns False when no buckets exist.""" - with patch("extensions.storage.supabase_storage.dify_config") as mock_config: + with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config: mock_config.SUPABASE_URL = "https://test.supabase.co" mock_config.SUPABASE_API_KEY = "test-api-key" mock_config.SUPABASE_BUCKET_NAME = "test-bucket" - with patch("extensions.storage.supabase_storage.Client") as mock_client_class: + with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class: mock_client = Mock() mock_client_class.return_value = mock_client diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 2a193ef2d7..1e98e99aab 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,7 +4,7 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import given +from hypothesis import given, settings from hypothesis import strategies as st from core.file import File, FileTransferMethod, FileType @@ -486,13 +486,14 @@ def _generate_file(draw) -> File: def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]: return st.one_of( st.none(), - st.integers(), - st.floats(), - st.text(), + st.integers(min_value=-(10**6), max_value=10**6), + st.floats(allow_nan=True, allow_infinity=False), + st.text(max_size=50), _generate_file(), ) +@settings(max_examples=50) @given(_scalar_value()) def test_build_segment_and_extract_values_for_scalar_types(value): seg = variable_factory.build_segment(value) @@ -503,7 +504,8 @@ def test_build_segment_and_extract_values_for_scalar_types(value): assert seg.value == value -@given(st.lists(_scalar_value())) +@settings(max_examples=50) +@given(values=st.lists(_scalar_value(), max_size=20)) def test_build_segment_and_extract_values_for_array_types(values): seg = variable_factory.build_segment(values) assert seg.value == values diff --git a/api/tests/unit_tests/libs/test_flask_utils.py b/api/tests/unit_tests/libs/test_flask_utils.py index fb46ba50f3..e30433bfce 100644 --- a/api/tests/unit_tests/libs/test_flask_utils.py +++ b/api/tests/unit_tests/libs/test_flask_utils.py @@ -1,6 +1,5 @@ import contextvars import threading -from typing import Optional import pytest from flask import Flask @@ -29,7 +28,7 @@ def login_app(app: Flask) -> Flask: login_manager.init_app(app) @login_manager.user_loader - def load_user(user_id: str) -> Optional[User]: + def load_user(user_id: str) -> User | None: if user_id == "test_user": return User("test_user") return None diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index d7404ee90a..737202f8de 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -10,7 +10,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, - AccountNotFoundError, AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, @@ -195,7 +194,7 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password" + AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password" ) def test_authenticate_account_banned(self, mock_db_dependencies): diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index 1881ceac26..69766188f3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional # Mock redis_client before importing dataset_service from unittest.mock import Mock, call, patch @@ -37,7 +36,7 @@ class DocumentBatchUpdateTestDataFactory: enabled: bool = True, archived: bool = False, indexing_status: str = "completed", - completed_at: Optional[datetime.datetime] = None, + completed_at: datetime.datetime | None = None, **kwargs, ) -> Mock: """Create a mock document with specified attributes.""" diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index fb23863043..df5596f5c8 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, Optional +from typing import Any # Mock redis_client before importing dataset_service from unittest.mock import Mock, create_autospec, patch @@ -24,9 +24,9 @@ class DatasetUpdateTestDataFactory: description: str = "old_description", indexing_technique: str = "high_quality", retrieval_model: str = "old_model", - embedding_model_provider: Optional[str] = None, - embedding_model: Optional[str] = None, - collection_binding_id: Optional[str] = None, + embedding_model_provider: str | None = None, + embedding_model: str | None = None, + collection_binding_id: str | None = None, **kwargs, ) -> Mock: """Create a mock dataset with specified attributes.""" diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index ad65175e89..0ff1edc950 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import Mock, create_autospec, patch import pytest @@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation: # Console API create console_create_file = "api/controllers/console/datasets/metadata.py" if os.path.exists(console_create_file): - with open(console_create_file) as f: - content = f.read() - # Should contain nullable=False, not nullable=True - assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] + content = Path(console_create_file).read_text() + # Should contain nullable=False, not nullable=True + assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] # Service API create service_create_file = "api/controllers/service_api/dataset/metadata.py" if os.path.exists(service_create_file): - with open(service_create_file) as f: - content = f.read() - # Should contain nullable=False, not nullable=True - create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] - assert "nullable=True" not in create_api_section + content = Path(service_create_file).read_text() + # Should contain nullable=False, not nullable=True + create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] + assert "nullable=True" not in create_api_section class TestMetadataValidationSummary: diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 0a09167349..2ca781bae5 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT.value + app_model.mode = AppMode.CHAT api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( @@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW.value + app_model.mode = AppMode.WORKFLOW api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 93284eed4b..9046f785d2 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -279,8 +279,6 @@ def test_structured_output_parser(): ] for case in testcases: - print(f"Running test case: {case['name']}") - # Setup model entity model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"]) diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d5..9e2b0659c0 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -3,7 +3,7 @@ from textwrap import dedent import pytest from yaml import YAMLError -from core.tools.utils.yaml_utils import load_yaml_file +from core.tools.utils.yaml_utils import _load_yaml_file EXAMPLE_YAML_FILE = "example_yaml.yaml" INVALID_YAML_FILE = "invalid_yaml.yaml" @@ -56,15 +56,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: def test_load_yaml_non_existing_file(): - assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path="") == {} + with pytest.raises(FileNotFoundError): + _load_yaml_file(file_path=NON_EXISTING_YAML_FILE) with pytest.raises(FileNotFoundError): - load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) + _load_yaml_file(file_path="") def test_load_valid_yaml_file(prepare_example_yaml_file): - yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) + yaml_data = _load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 assert yaml_data["age"] == 30 assert yaml_data["gender"] == "male" @@ -77,7 +77,4 @@ def test_load_valid_yaml_file(prepare_example_yaml_file): def test_load_invalid_yaml_file(prepare_invalid_yaml_file): # yaml syntax error with pytest.raises(YAMLError): - load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) - - # ignore error - assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {} + _load_yaml_file(file_path=prepare_invalid_yaml_file) diff --git a/api/uv.lock b/api/uv.lock index 54c4083369..56ce7108e3 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -538,6 +538,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a", size = 142979, upload-time = "2023-04-07T15:02:50.77Z" }, ] +[[package]] +name = "bidict" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093, upload-time = "2024-02-18T19:09:05.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size = 32764, upload-time = "2024-02-18T19:09:04.156Z" }, +] + [[package]] name = "billiard" version = "4.2.1" @@ -1009,7 +1018,7 @@ wheels = [ [[package]] name = "clickzetta-connector-python" -version = "0.8.102" +version = "0.8.104" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, @@ -1023,7 +1032,7 @@ dependencies = [ { name = "urllib3" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" }, + { url = "https://files.pythonhosted.org/packages/8f/94/c7eee2224bdab39d16dfe5bb7687f5525c7ed345b7fe8812e18a2d9a6335/clickzetta_connector_python-0.8.104-py3-none-any.whl", hash = "sha256:ae3e466d990677f96c769ec1c29318237df80c80fe9c1e21ba1eaf42bdef0207", size = 79382, upload-time = "2025-09-10T08:46:39.731Z" }, ] [[package]] @@ -1061,6 +1070,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018, upload-time = "2021-06-11T10:22:42.561Z" }, ] +[[package]] +name = "configargparse" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/4d/6c9ef746dfcc2a32e26f3860bb4a011c008c392b83eabdfb598d1a8bbe5d/configargparse-1.7.1.tar.gz", hash = "sha256:79c2ddae836a1e5914b71d58e4b9adbd9f7779d4e6351a637b7d2d9b6c46d3d9", size = 43958, upload-time = "2025-05-23T14:26:17.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/28/d28211d29bcc3620b1fece85a65ce5bb22f18670a03cd28ea4b75ede270c/configargparse-1.7.1-py3-none-any.whl", hash = "sha256:8b586a31f9d873abd1ca527ffbe58863c99f36d896e2829779803125e83be4b6", size = 25607, upload-time = "2025-05-23T14:26:15.923Z" }, +] + [[package]] name = "cos-python-sdk-v5" version = "1.9.30" @@ -1357,6 +1375,7 @@ dev = [ { name = "dotenv-linter" }, { name = "faker" }, { name = "hypothesis" }, + { name = "locust" }, { name = "lxml-stubs" }, { name = "mypy" }, { name = "pandas-stubs" }, @@ -1367,6 +1386,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "scipy-stubs" }, + { name = "sseclient-py" }, { name = "testcontainers" }, { name = "ty" }, { name = "types-aiofiles" }, @@ -1533,7 +1553,7 @@ requires-dist = [ { name = "sseclient-py", specifier = "~=1.8.0" }, { name = "starlette", specifier = "==0.47.2" }, { name = "tiktoken", specifier = "~=0.9.0" }, - { name = "transformers", specifier = "~=4.53.0" }, + { name = "transformers", specifier = "~=4.56.1" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, { name = "weave", specifier = "~=0.51.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, @@ -1549,6 +1569,7 @@ dev = [ { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, + { name = "locust", specifier = ">=2.40.4" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.17.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, @@ -1559,6 +1580,7 @@ dev = [ { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "ruff", specifier = "~=0.12.3" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, + { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.10.0" }, { name = "ty", specifier = "~=0.0.1a19" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, @@ -2036,6 +2058,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/b2/5d20664ef6a077bec9f27f7a7ee761edc64946d0b1e293726a3d074a9a18/gevent-24.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:68bee86b6e1c041a187347ef84cf03a792f0b6c7238378bf6ba4118af11feaae", size = 1541631, upload-time = "2024-11-11T14:55:34.977Z" }, ] +[[package]] +name = "geventhttpclient" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "brotli" }, + { name = "certifi" }, + { name = "gevent" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/19/1ca8de73dcc0596d3df01be299e940d7fc3bccbeb6f62bb8dd2d427a3a50/geventhttpclient-2.3.4.tar.gz", hash = "sha256:1749f75810435a001fc6d4d7526c92cf02b39b30ab6217a886102f941c874222", size = 83545, upload-time = "2025-06-11T13:18:14.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/c7/c4c31bd92b08c4e34073c722152b05c48c026bc6978cf04f52be7e9050d5/geventhttpclient-2.3.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb8f6a18f1b5e37724111abbd3edf25f8f00e43dc261b11b10686e17688d2405", size = 71919, upload-time = "2025-06-11T13:16:49.796Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8a/4565e6e768181ecb06677861d949b3679ed29123b6f14333e38767a17b5a/geventhttpclient-2.3.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dbb28455bb5d82ca3024f9eb7d65c8ff6707394b584519def497b5eb9e5b1222", size = 52577, upload-time = "2025-06-11T13:16:50.657Z" }, + { url = "https://files.pythonhosted.org/packages/02/a1/fb623cf478799c08f95774bc41edb8ae4c2f1317ae986b52f233d0f3fa05/geventhttpclient-2.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96578fc4a5707b5535d1c25a89e72583e02aafe64d14f3b4d78f9c512c6d613c", size = 51981, upload-time = "2025-06-11T13:16:52.586Z" }, + { url = "https://files.pythonhosted.org/packages/18/b2/a4ddd3d24c8aa064b19b9f180eb5e1517248518289d38af70500569ebedf/geventhttpclient-2.3.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19721357db976149ccf54ac279eab8139da8cdf7a11343fd02212891b6f39677", size = 114287, upload-time = "2025-08-24T12:16:47.101Z" }, + { url = "https://files.pythonhosted.org/packages/a1/cc/caac4d4bd2c72d53836dbf50018aed3747c0d0c6f1d08175a785083d9d36/geventhttpclient-2.3.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecf830cdcd1d4d28463c8e0c48f7f5fb06f3c952fff875da279385554d1d4d65", size = 115208, upload-time = "2025-08-24T12:16:48.108Z" }, + { url = "https://files.pythonhosted.org/packages/04/a2/8278bd4d16b9df88bd538824595b7b84efd6f03c7b56b2087d09be838e02/geventhttpclient-2.3.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:47dbf8a163a07f83b38b0f8a35b85e5d193d3af4522ab8a5bbecffff1a4cd462", size = 121101, upload-time = "2025-08-24T12:16:49.417Z" }, + { url = "https://files.pythonhosted.org/packages/e3/0e/a9ebb216140bd0854007ff953094b2af983cdf6d4aec49796572fcbf2606/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e39ad577b33a5be33b47bff7c2dda9b19ced4773d169d6555777cd8445c13c0", size = 118494, upload-time = "2025-06-11T13:16:54.172Z" }, + { url = "https://files.pythonhosted.org/packages/4f/95/6d45dead27e4f5db7a6d277354b0e2877c58efb3cd1687d90a02d5c7b9cd/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:110d863baf7f0a369b6c22be547c5582e87eea70ddda41894715c870b2e82eb0", size = 123860, upload-time = "2025-06-11T13:16:55.824Z" }, + { url = "https://files.pythonhosted.org/packages/70/a1/4baa8dca3d2df94e6ccca889947bb5929aca5b64b59136bbf1779b5777ba/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d9fca98469bd770e3efd88326854296d1aa68016f285bd1a2fb6cd21e17ee", size = 114969, upload-time = "2025-06-11T13:16:58.02Z" }, + { url = "https://files.pythonhosted.org/packages/ab/48/123fa67f6fca14c557332a168011565abd9cbdccc5c8b7ed76d9a736aeb2/geventhttpclient-2.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71dbc6d4004017ef88c70229809df4ad2317aad4876870c0b6bcd4d6695b7a8d", size = 113311, upload-time = "2025-06-11T13:16:59.423Z" }, + { url = "https://files.pythonhosted.org/packages/93/e4/8a467991127ca6c53dd79a8aecb26a48207e7e7976c578fb6eb31378792c/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ed35391ad697d6cda43c94087f59310f028c3e9fb229e435281a92509469c627", size = 111154, upload-time = "2025-06-11T13:17:01.139Z" }, + { url = "https://files.pythonhosted.org/packages/11/e7/cca0663d90bc8e68592a62d7b28148eb9fd976f739bb107e4c93f9ae6d81/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:97cd2ab03d303fd57dea4f6d9c2ab23b7193846f1b3bbb4c80b315ebb5fc8527", size = 112532, upload-time = "2025-06-11T13:17:03.729Z" }, + { url = "https://files.pythonhosted.org/packages/02/98/625cee18a3be5f7ca74c612d4032b0c013b911eb73c7e72e06fa56a44ba2/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ec4d1aa08569b7eb075942caeacabefee469a0e283c96c7aac0226d5e7598fe8", size = 117806, upload-time = "2025-06-11T13:17:05.138Z" }, + { url = "https://files.pythonhosted.org/packages/f1/5e/e561a5f8c9d98b7258685355aacb9cca8a3c714190cf92438a6e91da09d5/geventhttpclient-2.3.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:93926aacdb0f4289b558f213bc32c03578f3432a18b09e4b6d73a716839d7a74", size = 111392, upload-time = "2025-06-11T13:17:06.053Z" }, + { url = "https://files.pythonhosted.org/packages/d0/37/42d09ad90fd1da960ff68facaa3b79418ccf66297f202ba5361038fc3182/geventhttpclient-2.3.4-cp311-cp311-win32.whl", hash = "sha256:ea87c25e933991366049a42c88e91ad20c2b72e11c7bd38ef68f80486ab63cb2", size = 48332, upload-time = "2025-06-11T13:17:06.965Z" }, + { url = "https://files.pythonhosted.org/packages/4b/0b/55e2a9ed4b1aed7c97e857dc9649a7e804609a105e1ef3cb01da857fbce7/geventhttpclient-2.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:e02e0e9ef2e45475cf33816c8fb2e24595650bcf259e7b15b515a7b49cae1ccf", size = 48969, upload-time = "2025-06-11T13:17:08.239Z" }, + { url = "https://files.pythonhosted.org/packages/4f/72/dcbc6dbf838549b7b0c2c18c1365d2580eb7456939e4b608c3ab213fce78/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9ac30c38d86d888b42bb2ab2738ab9881199609e9fa9a153eb0c66fc9188c6cb", size = 71984, upload-time = "2025-06-11T13:17:09.126Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f9/74aa8c556364ad39b238919c954a0da01a6154ad5e85a1d1ab5f9f5ac186/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b802000a4fad80fa57e895009671d6e8af56777e3adf0d8aee0807e96188fd9", size = 52631, upload-time = "2025-06-11T13:17:10.061Z" }, + { url = "https://files.pythonhosted.org/packages/11/1a/bc4b70cba8b46be8b2c6ca5b8067c4f086f8c90915eb68086ab40ff6243d/geventhttpclient-2.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:461e4d9f4caee481788ec95ac64e0a4a087c1964ddbfae9b6f2dc51715ba706c", size = 51991, upload-time = "2025-06-11T13:17:11.049Z" }, + { url = "https://files.pythonhosted.org/packages/03/3f/5ce6e003b3b24f7caf3207285831afd1a4f857ce98ac45e1fb7a6815bd58/geventhttpclient-2.3.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b7e41687c74e8fbe6a665458bbaea0c5a75342a95e2583738364a73bcbf1671b", size = 114982, upload-time = "2025-08-24T12:16:50.76Z" }, + { url = "https://files.pythonhosted.org/packages/60/16/6f9dad141b7c6dd7ee831fbcd72dd02535c57bc1ec3c3282f07e72c31344/geventhttpclient-2.3.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ea5da20f4023cf40207ce15f5f4028377ffffdba3adfb60b4c8f34925fce79", size = 115654, upload-time = "2025-08-24T12:16:52.072Z" }, + { url = "https://files.pythonhosted.org/packages/ba/52/9b516a2ff423d8bd64c319e1950a165ceebb552781c5a88c1e94e93e8713/geventhttpclient-2.3.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:91f19a8a6899c27867dbdace9500f337d3e891a610708e86078915f1d779bf53", size = 121672, upload-time = "2025-08-24T12:16:53.361Z" }, + { url = "https://files.pythonhosted.org/packages/b0/f5/8d0f1e998f6d933c251b51ef92d11f7eb5211e3cd579018973a2b455f7c5/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41f2dcc0805551ea9d49f9392c3b9296505a89b9387417b148655d0d8251b36e", size = 119012, upload-time = "2025-06-11T13:17:11.956Z" }, + { url = "https://files.pythonhosted.org/packages/ea/0e/59e4ab506b3c19fc72e88ca344d150a9028a00c400b1099637100bec26fc/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:62f3a29bf242ecca6360d497304900683fd8f42cbf1de8d0546c871819251dad", size = 124565, upload-time = "2025-06-11T13:17:12.896Z" }, + { url = "https://files.pythonhosted.org/packages/39/5d/dcbd34dfcda0c016b4970bd583cb260cc5ebfc35b33d0ec9ccdb2293587a/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8714a3f2c093aeda3ffdb14c03571d349cb3ed1b8b461d9f321890659f4a5dbf", size = 115573, upload-time = "2025-06-11T13:17:13.937Z" }, + { url = "https://files.pythonhosted.org/packages/03/51/89af99e4805e9ce7f95562dfbd23c0b0391830831e43d58f940ec74489ac/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b11f38b74bab75282db66226197024a731250dcbe25542fd4e85ac5313547332", size = 114260, upload-time = "2025-06-11T13:17:14.913Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ec/3a3000bda432953abcc6f51d008166fa7abc1eeddd1f0246933d83854f73/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fccc2023a89dfbce2e1b1409b967011e45d41808df81b7fa0259397db79ba647", size = 111592, upload-time = "2025-06-11T13:17:15.879Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a3/88fd71fe6bbe1315a2d161cbe2cc7810c357d99bced113bea1668ede8bcf/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9d54b8e9a44890159ae36ba4ae44efd8bb79ff519055137a340d357538a68aa3", size = 113216, upload-time = "2025-06-11T13:17:16.883Z" }, + { url = "https://files.pythonhosted.org/packages/52/eb/20435585a6911b26e65f901a827ef13551c053133926f8c28a7cca0fb08e/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:407cb68a3c3a2c4f5d503930298f2b26ae68137d520e8846d8e230a9981d9334", size = 118450, upload-time = "2025-06-11T13:17:17.968Z" }, + { url = "https://files.pythonhosted.org/packages/2f/79/82782283d613570373990b676a0966c1062a38ca8f41a0f20843c5808e01/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:54fbbcca2dcf06f12a337dd8f98417a09a49aa9d9706aa530fc93acb59b7d83c", size = 112226, upload-time = "2025-06-11T13:17:18.942Z" }, + { url = "https://files.pythonhosted.org/packages/9c/c4/417d12fc2a31ad93172b03309c7f8c3a8bbd0cf25b95eb7835de26b24453/geventhttpclient-2.3.4-cp312-cp312-win32.whl", hash = "sha256:83143b41bde2eb010c7056f142cb764cfbf77f16bf78bda2323a160767455cf5", size = 48365, upload-time = "2025-06-11T13:17:20.096Z" }, + { url = "https://files.pythonhosted.org/packages/cf/f4/7e5ee2f460bbbd09cb5d90ff63a1cf80d60f1c60c29dac20326324242377/geventhttpclient-2.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:46eda9a9137b0ca7886369b40995d2a43a5dff033d0a839a54241015d1845d41", size = 48961, upload-time = "2025-06-11T13:17:21.111Z" }, + { url = "https://files.pythonhosted.org/packages/0b/a7/de506f91a1ec67d3c4a53f2aa7475e7ffb869a17b71b94ba370a027a69ac/geventhttpclient-2.3.4-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:707a66cd1e3bf06e2c4f8f21d3b4e6290c9e092456f489c560345a8663cdd93e", size = 50828, upload-time = "2025-06-11T13:17:57.589Z" }, + { url = "https://files.pythonhosted.org/packages/2b/43/86479c278e96cd3e190932b0003d5b8e415660d9e519d59094728ae249da/geventhttpclient-2.3.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:0129ce7ef50e67d66ea5de44d89a3998ab778a4db98093d943d6855323646fa5", size = 50086, upload-time = "2025-06-11T13:17:58.567Z" }, + { url = "https://files.pythonhosted.org/packages/e8/f7/d3e04f95de14db3ca4fe126eb0e3ec24356125c5ca1f471a9b28b1d7714d/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fac2635f68b3b6752c2a576833d9d18f0af50bdd4bd7dd2d2ca753e3b8add84c", size = 54523, upload-time = "2025-06-11T13:17:59.536Z" }, + { url = "https://files.pythonhosted.org/packages/45/a7/d80c9ec1663f70f4bd976978bf86b3d0d123a220c4ae636c66d02d3accdb/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71206ab89abdd0bd5fee21e04a3995ec1f7d8ae1478ee5868f9e16e85a831653", size = 58866, upload-time = "2025-06-11T13:18:03.719Z" }, + { url = "https://files.pythonhosted.org/packages/55/92/d874ff7e52803cef3850bf8875816a9f32e0a154b079a74e6663534bef30/geventhttpclient-2.3.4-pp311-pypy311_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8bde667d0ce46065fe57f8ff24b2e94f620a5747378c97314dcfc8fbab35b73", size = 54766, upload-time = "2025-06-11T13:18:04.724Z" }, + { url = "https://files.pythonhosted.org/packages/a8/73/2e03125170485193fcc99ef23b52749543d6c6711706d58713fe315869c4/geventhttpclient-2.3.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5f71c75fc138331cbbe668a08951d36b641d2c26fb3677d7e497afb8419538db", size = 49011, upload-time = "2025-06-11T13:18:05.702Z" }, +] + [[package]] name = "gitdb" version = "4.0.12" @@ -2627,7 +2701,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.33.2" +version = "0.34.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2639,9 +2713,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fa/42/8a95c5632080ae312c0498744b2b852195e10b05a20b1be11c5141092f4c/huggingface_hub-0.33.2.tar.gz", hash = "sha256:84221defaec8fa09c090390cd68c78b88e3c4c2b7befba68d3dc5aacbc3c2c5f", size = 426637, upload-time = "2025-07-02T06:26:05.156Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/c9/bdbe19339f76d12985bc03572f330a01a93c04dffecaaea3061bdd7fb892/huggingface_hub-0.34.4.tar.gz", hash = "sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c", size = 459768, upload-time = "2025-08-08T09:14:52.365Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/f4/5f3f22e762ad1965f01122b42dae5bf0e009286e2dba601ce1d0dba72424/huggingface_hub-0.33.2-py3-none-any.whl", hash = "sha256:3749498bfa91e8cde2ddc2c1db92c79981f40e66434c20133b39e5928ac9bcc5", size = 515373, upload-time = "2025-07-02T06:26:03.072Z" }, + { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" }, ] [[package]] @@ -2959,6 +3033,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/3b/a9a17366af80127bd09decbe2a54d8974b6d8b274b39bf47fbaedeec6307/llvmlite-0.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:eae7e2d4ca8f88f89d315b48c6b741dcb925d6a1042da694aa16ab3dd4cbd3a1", size = 30332380, upload-time = "2025-01-20T11:14:02.442Z" }, ] +[[package]] +name = "locust" +version = "2.40.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "flask" }, + { name = "flask-cors" }, + { name = "flask-login" }, + { name = "gevent" }, + { name = "geventhttpclient" }, + { name = "locust-cloud" }, + { name = "msgpack" }, + { name = "psutil" }, + { name = "pytest" }, + { name = "python-engineio" }, + { name = "python-socketio", extra = ["client"] }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pyzmq" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/40/31ff56ab6f46c7c77e61bbbd23f87fdf6a4aaf674dc961a3c573320caedc/locust-2.40.4.tar.gz", hash = "sha256:3a3a470459edc4ba1349229bf1aca4c0cb651c4e2e3f85d3bc28fe8118f5a18f", size = 1412529, upload-time = "2025-09-11T09:26:13.713Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7e/db1d969caf45ce711e81cd4f3e7c4554c3925a02383a1dcadb442eae3802/locust-2.40.4-py3-none-any.whl", hash = "sha256:50e647a73c5a4e7a775c6e4311979472fce8b00ed783837a2ce9bb36786f7d1a", size = 1430961, upload-time = "2025-09-11T09:26:11.623Z" }, +] + +[[package]] +name = "locust-cloud" +version = "1.26.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "gevent" }, + { name = "platformdirs" }, + { name = "python-engineio" }, + { name = "python-socketio", extra = ["client"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/ad/10b299b134068a4250a9156e6832a717406abe1dfea2482a07ae7bdca8f3/locust_cloud-1.26.3.tar.gz", hash = "sha256:587acfd4d2dee715fb5f0c3c2d922770babf0b7cff7b2927afbb693a9cd193cc", size = 456042, upload-time = "2025-07-15T19:51:53.791Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/6a/276fc50a9d170e7cbb6715735480cb037abb526639bca85491576e6eee4a/locust_cloud-1.26.3-py3-none-any.whl", hash = "sha256:8cb4b8bb9adcd5b99327bc8ed1d98cf67a29d9d29512651e6e94869de6f1faa8", size = 410023, upload-time = "2025-07-15T19:51:52.056Z" }, +] + [[package]] name = "lxml" version = "6.0.0" @@ -3230,6 +3349,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, ] +[[package]] +name = "msgpack" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/45/b1/ea4f68038a18c77c9467400d166d74c4ffa536f34761f7983a104357e614/msgpack-1.1.1.tar.gz", hash = "sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd", size = 173555, upload-time = "2025-06-13T06:52:51.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/83/97f24bf9848af23fe2ba04380388216defc49a8af6da0c28cc636d722502/msgpack-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558", size = 82728, upload-time = "2025-06-13T06:51:50.68Z" }, + { url = "https://files.pythonhosted.org/packages/aa/7f/2eaa388267a78401f6e182662b08a588ef4f3de6f0eab1ec09736a7aaa2b/msgpack-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d", size = 79279, upload-time = "2025-06-13T06:51:51.72Z" }, + { url = "https://files.pythonhosted.org/packages/f8/46/31eb60f4452c96161e4dfd26dbca562b4ec68c72e4ad07d9566d7ea35e8a/msgpack-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0", size = 423859, upload-time = "2025-06-13T06:51:52.749Z" }, + { url = "https://files.pythonhosted.org/packages/45/16/a20fa8c32825cc7ae8457fab45670c7a8996d7746ce80ce41cc51e3b2bd7/msgpack-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f", size = 429975, upload-time = "2025-06-13T06:51:53.97Z" }, + { url = "https://files.pythonhosted.org/packages/86/ea/6c958e07692367feeb1a1594d35e22b62f7f476f3c568b002a5ea09d443d/msgpack-1.1.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704", size = 413528, upload-time = "2025-06-13T06:51:55.507Z" }, + { url = "https://files.pythonhosted.org/packages/75/05/ac84063c5dae79722bda9f68b878dc31fc3059adb8633c79f1e82c2cd946/msgpack-1.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2", size = 413338, upload-time = "2025-06-13T06:51:57.023Z" }, + { url = "https://files.pythonhosted.org/packages/69/e8/fe86b082c781d3e1c09ca0f4dacd457ede60a13119b6ce939efe2ea77b76/msgpack-1.1.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2", size = 422658, upload-time = "2025-06-13T06:51:58.419Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2b/bafc9924df52d8f3bb7c00d24e57be477f4d0f967c0a31ef5e2225e035c7/msgpack-1.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752", size = 427124, upload-time = "2025-06-13T06:51:59.969Z" }, + { url = "https://files.pythonhosted.org/packages/a2/3b/1f717e17e53e0ed0b68fa59e9188f3f610c79d7151f0e52ff3cd8eb6b2dc/msgpack-1.1.1-cp311-cp311-win32.whl", hash = "sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295", size = 65016, upload-time = "2025-06-13T06:52:01.294Z" }, + { url = "https://files.pythonhosted.org/packages/48/45/9d1780768d3b249accecc5a38c725eb1e203d44a191f7b7ff1941f7df60c/msgpack-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458", size = 72267, upload-time = "2025-06-13T06:52:02.568Z" }, + { url = "https://files.pythonhosted.org/packages/e3/26/389b9c593eda2b8551b2e7126ad3a06af6f9b44274eb3a4f054d48ff7e47/msgpack-1.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238", size = 82359, upload-time = "2025-06-13T06:52:03.909Z" }, + { url = "https://files.pythonhosted.org/packages/ab/65/7d1de38c8a22cf8b1551469159d4b6cf49be2126adc2482de50976084d78/msgpack-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157", size = 79172, upload-time = "2025-06-13T06:52:05.246Z" }, + { url = "https://files.pythonhosted.org/packages/0f/bd/cacf208b64d9577a62c74b677e1ada005caa9b69a05a599889d6fc2ab20a/msgpack-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce", size = 425013, upload-time = "2025-06-13T06:52:06.341Z" }, + { url = "https://files.pythonhosted.org/packages/4d/ec/fd869e2567cc9c01278a736cfd1697941ba0d4b81a43e0aa2e8d71dab208/msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a", size = 426905, upload-time = "2025-06-13T06:52:07.501Z" }, + { url = "https://files.pythonhosted.org/packages/55/2a/35860f33229075bce803a5593d046d8b489d7ba2fc85701e714fc1aaf898/msgpack-1.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c", size = 407336, upload-time = "2025-06-13T06:52:09.047Z" }, + { url = "https://files.pythonhosted.org/packages/8c/16/69ed8f3ada150bf92745fb4921bd621fd2cdf5a42e25eb50bcc57a5328f0/msgpack-1.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b", size = 409485, upload-time = "2025-06-13T06:52:10.382Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b6/0c398039e4c6d0b2e37c61d7e0e9d13439f91f780686deb8ee64ecf1ae71/msgpack-1.1.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef", size = 412182, upload-time = "2025-06-13T06:52:11.644Z" }, + { url = "https://files.pythonhosted.org/packages/b8/d0/0cf4a6ecb9bc960d624c93effaeaae75cbf00b3bc4a54f35c8507273cda1/msgpack-1.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a", size = 419883, upload-time = "2025-06-13T06:52:12.806Z" }, + { url = "https://files.pythonhosted.org/packages/62/83/9697c211720fa71a2dfb632cad6196a8af3abea56eece220fde4674dc44b/msgpack-1.1.1-cp312-cp312-win32.whl", hash = "sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c", size = 65406, upload-time = "2025-06-13T06:52:14.271Z" }, + { url = "https://files.pythonhosted.org/packages/c0/23/0abb886e80eab08f5e8c485d6f13924028602829f63b8f5fa25a06636628/msgpack-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4", size = 72558, upload-time = "2025-06-13T06:52:15.252Z" }, +] + [[package]] name = "msrest" version = "0.7.1" @@ -4826,6 +4973,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863, upload-time = "2024-01-23T06:32:58.246Z" }, ] +[[package]] +name = "python-engineio" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simple-websocket" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/0b/67295279b66835f9fa7a491650efcd78b20321c127036eef62c11a31e028/python_engineio-4.12.2.tar.gz", hash = "sha256:e7e712ffe1be1f6a05ee5f951e72d434854a32fcfc7f6e4d9d3cae24ec70defa", size = 91677, upload-time = "2025-06-04T19:22:18.789Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, +] + [[package]] name = "python-http-client" version = "3.3.7" @@ -4882,6 +5041,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "python-socketio" +version = "5.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bidict" }, + { name = "python-engineio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/1a/396d50ccf06ee539fa758ce5623b59a9cb27637fc4b2dc07ed08bf495e77/python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029", size = 121125, upload-time = "2025-04-12T15:46:59.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/32/b4fb8585d1be0f68bde7e110dffbcf354915f77ad8c778563f0ad9655c02/python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf", size = 77800, upload-time = "2025-04-12T15:46:58.412Z" }, +] + +[package.optional-dependencies] +client = [ + { name = "requests" }, + { name = "websocket-client" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -4939,6 +5117,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" }, ] +[[package]] +name = "pyzmq" +version = "27.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/5d/305323ba86b284e6fcb0d842d6adaa2999035f70f8c38a9b6d21ad28c3d4/pyzmq-27.1.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:226b091818d461a3bef763805e75685e478ac17e9008f49fce2d3e52b3d58b86", size = 1333328, upload-time = "2025-09-08T23:07:45.946Z" }, + { url = "https://files.pythonhosted.org/packages/bd/a0/fc7e78a23748ad5443ac3275943457e8452da67fda347e05260261108cbc/pyzmq-27.1.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:0790a0161c281ca9723f804871b4027f2e8b5a528d357c8952d08cd1a9c15581", size = 908803, upload-time = "2025-09-08T23:07:47.551Z" }, + { url = "https://files.pythonhosted.org/packages/7e/22/37d15eb05f3bdfa4abea6f6d96eb3bb58585fbd3e4e0ded4e743bc650c97/pyzmq-27.1.0-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c895a6f35476b0c3a54e3eb6ccf41bf3018de937016e6e18748317f25d4e925f", size = 668836, upload-time = "2025-09-08T23:07:49.436Z" }, + { url = "https://files.pythonhosted.org/packages/b1/c4/2a6fe5111a01005fc7af3878259ce17684fabb8852815eda6225620f3c59/pyzmq-27.1.0-cp311-cp311-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bbf8d3630bf96550b3be8e1fc0fea5cbdc8d5466c1192887bd94869da17a63e", size = 857038, upload-time = "2025-09-08T23:07:51.234Z" }, + { url = "https://files.pythonhosted.org/packages/cb/eb/bfdcb41d0db9cd233d6fb22dc131583774135505ada800ebf14dfb0a7c40/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:15c8bd0fe0dabf808e2d7a681398c4e5ded70a551ab47482067a572c054c8e2e", size = 1657531, upload-time = "2025-09-08T23:07:52.795Z" }, + { url = "https://files.pythonhosted.org/packages/ab/21/e3180ca269ed4a0de5c34417dfe71a8ae80421198be83ee619a8a485b0c7/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bafcb3dd171b4ae9f19ee6380dfc71ce0390fefaf26b504c0e5f628d7c8c54f2", size = 2034786, upload-time = "2025-09-08T23:07:55.047Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b1/5e21d0b517434b7f33588ff76c177c5a167858cc38ef740608898cd329f2/pyzmq-27.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e829529fcaa09937189178115c49c504e69289abd39967cd8a4c215761373394", size = 1894220, upload-time = "2025-09-08T23:07:57.172Z" }, + { url = "https://files.pythonhosted.org/packages/03/f2/44913a6ff6941905efc24a1acf3d3cb6146b636c546c7406c38c49c403d4/pyzmq-27.1.0-cp311-cp311-win32.whl", hash = "sha256:6df079c47d5902af6db298ec92151db82ecb557af663098b92f2508c398bb54f", size = 567155, upload-time = "2025-09-08T23:07:59.05Z" }, + { url = "https://files.pythonhosted.org/packages/23/6d/d8d92a0eb270a925c9b4dd039c0b4dc10abc2fcbc48331788824ef113935/pyzmq-27.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:190cbf120fbc0fc4957b56866830def56628934a9d112aec0e2507aa6a032b97", size = 633428, upload-time = "2025-09-08T23:08:00.663Z" }, + { url = "https://files.pythonhosted.org/packages/ae/14/01afebc96c5abbbd713ecfc7469cfb1bc801c819a74ed5c9fad9a48801cb/pyzmq-27.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:eca6b47df11a132d1745eb3b5b5e557a7dae2c303277aa0e69c6ba91b8736e07", size = 559497, upload-time = "2025-09-08T23:08:02.15Z" }, + { url = "https://files.pythonhosted.org/packages/92/e7/038aab64a946d535901103da16b953c8c9cc9c961dadcbf3609ed6428d23/pyzmq-27.1.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:452631b640340c928fa343801b0d07eb0c3789a5ffa843f6e1a9cee0ba4eb4fc", size = 1306279, upload-time = "2025-09-08T23:08:03.807Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5e/c3c49fdd0f535ef45eefcc16934648e9e59dace4a37ee88fc53f6cd8e641/pyzmq-27.1.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1c179799b118e554b66da67d88ed66cd37a169f1f23b5d9f0a231b4e8d44a113", size = 895645, upload-time = "2025-09-08T23:08:05.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/b0b2504cb4e903a74dcf1ebae157f9e20ebb6ea76095f6cfffea28c42ecd/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3837439b7f99e60312f0c926a6ad437b067356dc2bc2ec96eb395fd0fe804233", size = 652574, upload-time = "2025-09-08T23:08:06.828Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9b/c108cdb55560eaf253f0cbdb61b29971e9fb34d9c3499b0e96e4e60ed8a5/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43ad9a73e3da1fab5b0e7e13402f0b2fb934ae1c876c51d0afff0e7c052eca31", size = 840995, upload-time = "2025-09-08T23:08:08.396Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bb/b79798ca177b9eb0825b4c9998c6af8cd2a7f15a6a1a4272c1d1a21d382f/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0de3028d69d4cdc475bfe47a6128eb38d8bc0e8f4d69646adfbcd840facbac28", size = 1642070, upload-time = "2025-09-08T23:08:09.989Z" }, + { url = "https://files.pythonhosted.org/packages/9c/80/2df2e7977c4ede24c79ae39dcef3899bfc5f34d1ca7a5b24f182c9b7a9ca/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:cf44a7763aea9298c0aa7dbf859f87ed7012de8bda0f3977b6fb1d96745df856", size = 2021121, upload-time = "2025-09-08T23:08:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/46/bd/2d45ad24f5f5ae7e8d01525eb76786fa7557136555cac7d929880519e33a/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f30f395a9e6fbca195400ce833c731e7b64c3919aa481af4d88c3759e0cb7496", size = 1878550, upload-time = "2025-09-08T23:08:13.513Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2f/104c0a3c778d7c2ab8190e9db4f62f0b6957b53c9d87db77c284b69f33ea/pyzmq-27.1.0-cp312-abi3-win32.whl", hash = "sha256:250e5436a4ba13885494412b3da5d518cd0d3a278a1ae640e113c073a5f88edd", size = 559184, upload-time = "2025-09-08T23:08:15.163Z" }, + { url = "https://files.pythonhosted.org/packages/fc/7f/a21b20d577e4100c6a41795842028235998a643b1ad406a6d4163ea8f53e/pyzmq-27.1.0-cp312-abi3-win_amd64.whl", hash = "sha256:9ce490cf1d2ca2ad84733aa1d69ce6855372cb5ce9223802450c9b2a7cba0ccf", size = 619480, upload-time = "2025-09-08T23:08:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/78/c2/c012beae5f76b72f007a9e91ee9401cb88c51d0f83c6257a03e785c81cc2/pyzmq-27.1.0-cp312-abi3-win_arm64.whl", hash = "sha256:75a2f36223f0d535a0c919e23615fc85a1e23b71f40c7eb43d7b1dedb4d8f15f", size = 552993, upload-time = "2025-09-08T23:08:18.926Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c6/c4dcdecdbaa70969ee1fdced6d7b8f60cfabe64d25361f27ac4665a70620/pyzmq-27.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:18770c8d3563715387139060d37859c02ce40718d1faf299abddcdcc6a649066", size = 836265, upload-time = "2025-09-08T23:09:49.376Z" }, + { url = "https://files.pythonhosted.org/packages/3e/79/f38c92eeaeb03a2ccc2ba9866f0439593bb08c5e3b714ac1d553e5c96e25/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ac25465d42f92e990f8d8b0546b01c391ad431c3bf447683fdc40565941d0604", size = 800208, upload-time = "2025-09-08T23:09:51.073Z" }, + { url = "https://files.pythonhosted.org/packages/49/0e/3f0d0d335c6b3abb9b7b723776d0b21fa7f3a6c819a0db6097059aada160/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53b40f8ae006f2734ee7608d59ed661419f087521edbfc2149c3932e9c14808c", size = 567747, upload-time = "2025-09-08T23:09:52.698Z" }, + { url = "https://files.pythonhosted.org/packages/a1/cf/f2b3784d536250ffd4be70e049f3b60981235d70c6e8ce7e3ef21e1adb25/pyzmq-27.1.0-pp311-pypy311_pp73-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f605d884e7c8be8fe1aa94e0a783bf3f591b84c24e4bc4f3e7564c82ac25e271", size = 747371, upload-time = "2025-09-08T23:09:54.563Z" }, + { url = "https://files.pythonhosted.org/packages/01/1b/5dbe84eefc86f48473947e2f41711aded97eecef1231f4558f1f02713c12/pyzmq-27.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c9f7f6e13dff2e44a6afeaf2cf54cee5929ad64afaf4d40b50f93c58fc687355", size = 544862, upload-time = "2025-09-08T23:09:56.509Z" }, +] + [[package]] name = "qdrant-client" version = "1.9.0" @@ -5387,6 +5601,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "simple-websocket" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/d4/bfa032f961103eba93de583b161f0e6a5b63cebb8f2c7d0c6e6efe1e3d2e/simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4", size = 17300, upload-time = "2024-10-10T22:39:31.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/59/0782e51887ac6b07ffd1570e0364cf901ebc36345fea669969d2084baebb/simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c", size = 13842, upload-time = "2024-10-10T22:39:29.645Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -5706,27 +5932,27 @@ wheels = [ [[package]] name = "tokenizers" -version = "0.21.2" +version = "0.22.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/2d/b0fce2b8201635f60e8c95990080f58461cc9ca3d5026de2e900f38a7f21/tokenizers-0.21.2.tar.gz", hash = "sha256:fdc7cffde3e2113ba0e6cc7318c40e3438a4d74bbc62bf04bcc63bdfb082ac77", size = 351545, upload-time = "2025-06-24T10:24:52.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/b4/c1ce3699e81977da2ace8b16d2badfd42b060e7d33d75c4ccdbf9dc920fa/tokenizers-0.22.0.tar.gz", hash = "sha256:2e33b98525be8453f355927f3cab312c36cd3e44f4d7e9e97da2fa94d0a49dcb", size = 362771, upload-time = "2025-08-29T10:25:33.914Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/cc/2936e2d45ceb130a21d929743f1e9897514691bec123203e10837972296f/tokenizers-0.21.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:342b5dfb75009f2255ab8dec0041287260fed5ce00c323eb6bab639066fef8ec", size = 2875206, upload-time = "2025-06-24T10:24:42.755Z" }, - { url = "https://files.pythonhosted.org/packages/6c/e6/33f41f2cc7861faeba8988e7a77601407bf1d9d28fc79c5903f8f77df587/tokenizers-0.21.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:126df3205d6f3a93fea80c7a8a266a78c1bd8dd2fe043386bafdd7736a23e45f", size = 2732655, upload-time = "2025-06-24T10:24:41.56Z" }, - { url = "https://files.pythonhosted.org/packages/33/2b/1791eb329c07122a75b01035b1a3aa22ad139f3ce0ece1b059b506d9d9de/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a32cd81be21168bd0d6a0f0962d60177c447a1aa1b1e48fa6ec9fc728ee0b12", size = 3019202, upload-time = "2025-06-24T10:24:31.791Z" }, - { url = "https://files.pythonhosted.org/packages/05/15/fd2d8104faa9f86ac68748e6f7ece0b5eb7983c7efc3a2c197cb98c99030/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8bd8999538c405133c2ab999b83b17c08b7fc1b48c1ada2469964605a709ef91", size = 2934539, upload-time = "2025-06-24T10:24:34.567Z" }, - { url = "https://files.pythonhosted.org/packages/a5/2e/53e8fd053e1f3ffbe579ca5f9546f35ac67cf0039ed357ad7ec57f5f5af0/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e9944e61239b083a41cf8fc42802f855e1dca0f499196df37a8ce219abac6eb", size = 3248665, upload-time = "2025-06-24T10:24:39.024Z" }, - { url = "https://files.pythonhosted.org/packages/00/15/79713359f4037aa8f4d1f06ffca35312ac83629da062670e8830917e2153/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:514cd43045c5d546f01142ff9c79a96ea69e4b5cda09e3027708cb2e6d5762ab", size = 3451305, upload-time = "2025-06-24T10:24:36.133Z" }, - { url = "https://files.pythonhosted.org/packages/38/5f/959f3a8756fc9396aeb704292777b84f02a5c6f25c3fc3ba7530db5feb2c/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b1b9405822527ec1e0f7d8d2fdb287a5730c3a6518189c968254a8441b21faae", size = 3214757, upload-time = "2025-06-24T10:24:37.784Z" }, - { url = "https://files.pythonhosted.org/packages/c5/74/f41a432a0733f61f3d21b288de6dfa78f7acff309c6f0f323b2833e9189f/tokenizers-0.21.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fed9a4d51c395103ad24f8e7eb976811c57fbec2af9f133df471afcd922e5020", size = 3121887, upload-time = "2025-06-24T10:24:40.293Z" }, - { url = "https://files.pythonhosted.org/packages/3c/6a/bc220a11a17e5d07b0dfb3b5c628621d4dcc084bccd27cfaead659963016/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2c41862df3d873665ec78b6be36fcc30a26e3d4902e9dd8608ed61d49a48bc19", size = 9091965, upload-time = "2025-06-24T10:24:44.431Z" }, - { url = "https://files.pythonhosted.org/packages/6c/bd/ac386d79c4ef20dc6f39c4706640c24823dca7ebb6f703bfe6b5f0292d88/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed21dc7e624e4220e21758b2e62893be7101453525e3d23264081c9ef9a6d00d", size = 9053372, upload-time = "2025-06-24T10:24:46.455Z" }, - { url = "https://files.pythonhosted.org/packages/63/7b/5440bf203b2a5358f074408f7f9c42884849cd9972879e10ee6b7a8c3b3d/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:0e73770507e65a0e0e2a1affd6b03c36e3bc4377bd10c9ccf51a82c77c0fe365", size = 9298632, upload-time = "2025-06-24T10:24:48.446Z" }, - { url = "https://files.pythonhosted.org/packages/a4/d2/faa1acac3f96a7427866e94ed4289949b2524f0c1878512516567d80563c/tokenizers-0.21.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:106746e8aa9014a12109e58d540ad5465b4c183768ea96c03cbc24c44d329958", size = 9470074, upload-time = "2025-06-24T10:24:50.378Z" }, - { url = "https://files.pythonhosted.org/packages/d8/a5/896e1ef0707212745ae9f37e84c7d50269411aef2e9ccd0de63623feecdf/tokenizers-0.21.2-cp39-abi3-win32.whl", hash = "sha256:cabda5a6d15d620b6dfe711e1af52205266d05b379ea85a8a301b3593c60e962", size = 2330115, upload-time = "2025-06-24T10:24:55.069Z" }, - { url = "https://files.pythonhosted.org/packages/13/c3/cc2755ee10be859c4338c962a35b9a663788c0c0b50c0bdd8078fb6870cf/tokenizers-0.21.2-cp39-abi3-win_amd64.whl", hash = "sha256:58747bb898acdb1007f37a7bbe614346e98dc28708ffb66a3fd50ce169ac6c98", size = 2509918, upload-time = "2025-06-24T10:24:53.71Z" }, + { url = "https://files.pythonhosted.org/packages/6d/b1/18c13648edabbe66baa85fe266a478a7931ddc0cd1ba618802eb7b8d9865/tokenizers-0.22.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:eaa9620122a3fb99b943f864af95ed14c8dfc0f47afa3b404ac8c16b3f2bb484", size = 3081954, upload-time = "2025-08-29T10:25:24.993Z" }, + { url = "https://files.pythonhosted.org/packages/c2/02/c3c454b641bd7c4f79e4464accfae9e7dfc913a777d2e561e168ae060362/tokenizers-0.22.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:71784b9ab5bf0ff3075bceeb198149d2c5e068549c0d18fe32d06ba0deb63f79", size = 2945644, upload-time = "2025-08-29T10:25:23.405Z" }, + { url = "https://files.pythonhosted.org/packages/55/02/d10185ba2fd8c2d111e124c9d92de398aee0264b35ce433f79fb8472f5d0/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec5b71f668a8076802b0241a42387d48289f25435b86b769ae1837cad4172a17", size = 3254764, upload-time = "2025-08-29T10:25:12.445Z" }, + { url = "https://files.pythonhosted.org/packages/13/89/17514bd7ef4bf5bfff58e2b131cec0f8d5cea2b1c8ffe1050a2c8de88dbb/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ea8562fa7498850d02a16178105b58803ea825b50dc9094d60549a7ed63654bb", size = 3161654, upload-time = "2025-08-29T10:25:15.493Z" }, + { url = "https://files.pythonhosted.org/packages/5a/d8/bac9f3a7ef6dcceec206e3857c3b61bb16c6b702ed7ae49585f5bd85c0ef/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4136e1558a9ef2e2f1de1555dcd573e1cbc4a320c1a06c4107a3d46dc8ac6e4b", size = 3511484, upload-time = "2025-08-29T10:25:20.477Z" }, + { url = "https://files.pythonhosted.org/packages/aa/27/9c9800eb6763683010a4851db4d1802d8cab9cec114c17056eccb4d4a6e0/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf5954de3962a5fd9781dc12048d24a1a6f1f5df038c6e95db328cd22964206", size = 3712829, upload-time = "2025-08-29T10:25:17.154Z" }, + { url = "https://files.pythonhosted.org/packages/10/e3/b1726dbc1f03f757260fa21752e1921445b5bc350389a8314dd3338836db/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8337ca75d0731fc4860e6204cc24bb36a67d9736142aa06ed320943b50b1e7ed", size = 3408934, upload-time = "2025-08-29T10:25:18.76Z" }, + { url = "https://files.pythonhosted.org/packages/d4/61/aeab3402c26874b74bb67a7f2c4b569dde29b51032c5384db592e7b216f4/tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a89264e26f63c449d8cded9061adea7b5de53ba2346fc7e87311f7e4117c1cc8", size = 3345585, upload-time = "2025-08-29T10:25:22.08Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d3/498b4a8a8764cce0900af1add0f176ff24f475d4413d55b760b8cdf00893/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:790bad50a1b59d4c21592f9c3cf5e5cf9c3c7ce7e1a23a739f13e01fb1be377a", size = 9322986, upload-time = "2025-08-29T10:25:26.607Z" }, + { url = "https://files.pythonhosted.org/packages/a2/62/92378eb1c2c565837ca3cb5f9569860d132ab9d195d7950c1ea2681dffd0/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:76cf6757c73a10ef10bf06fa937c0ec7393d90432f543f49adc8cab3fb6f26cb", size = 9276630, upload-time = "2025-08-29T10:25:28.349Z" }, + { url = "https://files.pythonhosted.org/packages/eb/f0/342d80457aa1cda7654327460f69db0d69405af1e4c453f4dc6ca7c4a76e/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:1626cb186e143720c62c6c6b5371e62bbc10af60481388c0da89bc903f37ea0c", size = 9547175, upload-time = "2025-08-29T10:25:29.989Z" }, + { url = "https://files.pythonhosted.org/packages/14/84/8aa9b4adfc4fbd09381e20a5bc6aa27040c9c09caa89988c01544e008d18/tokenizers-0.22.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:da589a61cbfea18ae267723d6b029b84598dc8ca78db9951d8f5beff72d8507c", size = 9692735, upload-time = "2025-08-29T10:25:32.089Z" }, + { url = "https://files.pythonhosted.org/packages/bf/24/83ee2b1dc76bfe05c3142e7d0ccdfe69f0ad2f1ebf6c726cea7f0874c0d0/tokenizers-0.22.0-cp39-abi3-win32.whl", hash = "sha256:dbf9d6851bddae3e046fedfb166f47743c1c7bd11c640f0691dd35ef0bcad3be", size = 2471915, upload-time = "2025-08-29T10:25:36.411Z" }, + { url = "https://files.pythonhosted.org/packages/d1/9b/0e0bf82214ee20231845b127aa4a8015936ad5a46779f30865d10e404167/tokenizers-0.22.0-cp39-abi3-win_amd64.whl", hash = "sha256:c78174859eeaee96021f248a56c801e36bfb6bd5b067f2e95aa82445ca324f00", size = 2680494, upload-time = "2025-08-29T10:25:35.14Z" }, ] [[package]] @@ -5794,7 +6020,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.53.3" +version = "4.56.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -5808,9 +6034,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/5c/49182918b58eaa0b4c954fd0e37c79fc299e5643e69d70089d0b0eb0cd9b/transformers-4.53.3.tar.gz", hash = "sha256:b2eda1a261de79b78b97f7888fe2005fc0c3fabf5dad33d52cc02983f9f675d8", size = 9197478, upload-time = "2025-07-22T07:30:51.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/89/21/dc88ef3da1e49af07ed69386a11047a31dcf1aaf4ded3bc4b173fbf94116/transformers-4.56.1.tar.gz", hash = "sha256:0d88b1089a563996fc5f2c34502f10516cad3ea1aa89f179f522b54c8311fe74", size = 9855473, upload-time = "2025-09-04T20:47:13.14Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/b1/d7520cc5cb69c825599042eb3a7c986fa9baa8a8d2dea9acd78e152c81e2/transformers-4.53.3-py3-none-any.whl", hash = "sha256:5aba81c92095806b6baf12df35d756cf23b66c356975fb2a7fa9e536138d7c75", size = 10826382, upload-time = "2025-07-22T07:30:48.458Z" }, + { url = "https://files.pythonhosted.org/packages/71/7c/283c3dd35e00e22a7803a0b2a65251347b745474a82399be058bde1c9f15/transformers-4.56.1-py3-none-any.whl", hash = "sha256:1697af6addfb6ddbce9618b763f4b52d5a756f6da4899ffd1b4febf58b779248", size = 11608197, upload-time = "2025-09-04T20:47:04.895Z" }, ] [[package]] @@ -6794,6 +7020,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" }, ] +[[package]] +name = "wsproto" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425, upload-time = "2022-08-23T19:58:21.447Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" }, +] + [[package]] name = "xinference-client" version = "1.2.2" diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py index 63d0cbaf3a..1ec95deb09 100644 --- a/dev/pytest/pytest_config_tests.py +++ b/dev/pytest/pytest_config_tests.py @@ -1,6 +1,7 @@ +from pathlib import Path + import yaml # type: ignore from dotenv import dotenv_values -from pathlib import Path BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "APP_MAX_EXECUTION_TIME", @@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f: def test_yaml_config(): # python set == operator is used to compare two sets - DIFF_API_WITH_DOCKER = ( - API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF - ) + DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF if DIFF_API_WITH_DOCKER: - print( - f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}" - ) + print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}") raise Exception("API and Docker config sets are different") DIFF_API_WITH_DOCKER_COMPOSE = ( - API_CONFIG_SET - - DOCKER_COMPOSE_CONFIG_SET - - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF + API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF ) if DIFF_API_WITH_DOCKER_COMPOSE: - print( - f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}" - ) + print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}") raise Exception("API and Docker Compose config sets are different") print("All tests passed!") diff --git a/docker/.env.example b/docker/.env.example index 9a0a5a9622..92347a6e76 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -843,6 +843,7 @@ INVITE_EXPIRY_HOURS=72 # Reset password token valid time (minutes), RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3f19dc7f63..193157b54f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -372,6 +372,7 @@ x-shared-env: &shared-api-worker-env INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} + EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: ${EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES:-5} CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: ${CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES:-5} OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5} CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} diff --git a/scripts/stress-test/README.md b/scripts/stress-test/README.md new file mode 100644 index 0000000000..15f21cd532 --- /dev/null +++ b/scripts/stress-test/README.md @@ -0,0 +1,521 @@ +# Dify Stress Test Suite + +A high-performance stress test suite for Dify workflow execution using **Locust** - optimized for measuring Server-Sent Events (SSE) streaming performance. + +## Key Metrics Tracked + +The stress test focuses on four critical SSE performance indicators: + +1. **Active SSE Connections** - Real-time count of open SSE connections +1. **New Connection Rate** - Connections per second (conn/sec) +1. **Time to First Event (TTFE)** - Latency until first SSE event arrives +1. **Event Throughput** - Events per second (events/sec) + +## Features + +- **True SSE Support**: Properly handles Server-Sent Events streaming without premature connection closure +- **Real-time Metrics**: Live reporting every 5 seconds during tests +- **Comprehensive Tracking**: + - Active connection monitoring + - Connection establishment rate + - Event processing throughput + - TTFE distribution analysis +- **Multiple Interfaces**: + - Web UI for real-time monitoring () + - Headless mode with periodic console updates +- **Detailed Reports**: Final statistics with overall rates and averages +- **Easy Configuration**: Uses existing API key configuration from setup + +## What Gets Measured + +The stress test focuses on SSE streaming performance with these key metrics: + +### Primary Endpoint: `/v1/workflows/run` + +The stress test tests a single endpoint with comprehensive SSE metrics tracking: + +- **Request Type**: POST request to workflow execution API +- **Response Type**: Server-Sent Events (SSE) stream +- **Payload**: Random questions from a configurable pool +- **Concurrency**: Configurable from 1 to 1000+ simultaneous users + +### Key Performance Metrics + +#### 1. **Active Connections** + +- **What it measures**: Number of concurrent SSE connections open at any moment +- **Why it matters**: Shows system's ability to handle parallel streams +- **Good values**: Should remain stable under load without drops + +#### 2. **Connection Rate (conn/sec)** + +- **What it measures**: How fast new SSE connections are established +- **Why it matters**: Indicates system's ability to handle connection spikes +- **Good values**: + - Light load: 5-10 conn/sec + - Medium load: 20-50 conn/sec + - Heavy load: 100+ conn/sec + +#### 3. **Time to First Event (TTFE)** + +- **What it measures**: Latency from request sent to first SSE event received +- **Why it matters**: Critical for user experience - faster TTFE = better perceived performance +- **Good values**: + - Excellent: < 50ms + - Good: 50-100ms + - Acceptable: 100-500ms + - Poor: > 500ms + +#### 4. **Event Throughput (events/sec)** + +- **What it measures**: Rate of SSE events being delivered across all connections +- **Why it matters**: Shows actual data delivery performance +- **Expected values**: Depends on workflow complexity and number of connections + - Single connection: 10-20 events/sec + - 10 connections: 50-100 events/sec + - 100 connections: 200-500 events/sec + +#### 5. **Request/Response Times** + +- **P50 (Median)**: 50% of requests complete within this time +- **P95**: 95% of requests complete within this time +- **P99**: 99% of requests complete within this time +- **Min/Max**: Best and worst case response times + +## Prerequisites + +1. **Dependencies are automatically installed** when running setup: + + - Locust (load testing framework) + - sseclient-py (SSE client library) + +1. **Complete Dify setup**: + + ```bash + # Run the complete setup + python scripts/stress-test/setup_all.py + ``` + +1. **Ensure services are running**: + + **IMPORTANT**: For accurate stress testing, run the API server with Gunicorn in production mode: + + ```bash + # Run from the api directory + cd api + uv run gunicorn \ + --bind 0.0.0.0:5001 \ + --workers 4 \ + --worker-class gevent \ + --timeout 120 \ + --keep-alive 5 \ + --log-level info \ + --access-logfile - \ + --error-logfile - \ + app:app + ``` + + **Configuration options explained**: + + - `--workers 4`: Number of worker processes (adjust based on CPU cores) + - `--worker-class gevent`: Async worker for handling concurrent connections + - `--timeout 120`: Worker timeout for long-running requests + - `--keep-alive 5`: Keep connections alive for SSE streaming + + **NOT RECOMMENDED for stress testing**: + + ```bash + # Debug mode - DO NOT use for stress testing (slow performance) + ./dev/start-api # This runs Flask in debug mode with single-threaded execution + ``` + + **Also start the Mock OpenAI server**: + + ```bash + python scripts/stress-test/setup/mock_openai_server.py + ``` + +## Running the Stress Test + +```bash +# Run with default configuration (headless mode) +./scripts/stress-test/run_locust_stress_test.sh + +# Or run directly with uv +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py --host http://localhost:5001 + +# Run with Web UI (access at http://localhost:8089) +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py --host http://localhost:5001 --web-port 8089 +``` + +The script will: + +1. Validate that all required services are running +1. Check API token availability +1. Execute the Locust stress test with SSE support +1. Generate comprehensive reports in the `reports/` directory + +## Configuration + +The stress test configuration is in `locust.conf`: + +```ini +users = 10 # Number of concurrent users +spawn-rate = 2 # Users spawned per second +run-time = 1m # Test duration (30s, 5m, 1h) +headless = true # Run without web UI +``` + +### Custom Question Sets + +Modify the questions list in `sse_benchmark.py`: + +```python +self.questions = [ + "Your custom question 1", + "Your custom question 2", + # Add more questions... +] +``` + +## Understanding the Results + +### Report Structure + +After running the stress test, you'll find these files in the `reports/` directory: + +- `locust_summary_YYYYMMDD_HHMMSS.txt` - Complete console output with metrics +- `locust_report_YYYYMMDD_HHMMSS.html` - Interactive HTML report with charts +- `locust_YYYYMMDD_HHMMSS_stats.csv` - CSV with detailed statistics +- `locust_YYYYMMDD_HHMMSS_stats_history.csv` - Time-series data + +### Key Metrics + +**Requests Per Second (RPS)**: + +- **Excellent**: > 50 RPS +- **Good**: 20-50 RPS +- **Acceptable**: 10-20 RPS +- **Needs Improvement**: < 10 RPS + +**Response Time Percentiles**: + +- **P50 (Median)**: 50% of requests complete within this time +- **P95**: 95% of requests complete within this time +- **P99**: 99% of requests complete within this time + +**Success Rate**: + +- Should be > 99% for production readiness +- Lower rates indicate errors or timeouts + +### Example Output + +```text +============================================================ +DIFY SSE STRESS TEST +============================================================ + +[2025-09-12 15:45:44,468] Starting test run with 10 users at 2 users/sec + +============================================================ +SSE Metrics | Active: 8 | Total Conn: 142 | Events: 2841 +Rates: 2.4 conn/s | 47.3 events/s | TTFE: 43ms +============================================================ + +Type Name # reqs # fails | Avg Min Max Med | req/s failures/s +---------|------------------------------|--------|--------|--------|--------|--------|--------|--------|----------- +POST /v1/workflows/run 142 0(0.00%) | 41 18 192 38 | 2.37 0.00 +---------|------------------------------|--------|--------|--------|--------|--------|--------|--------|----------- + Aggregated 142 0(0.00%) | 41 18 192 38 | 2.37 0.00 + +============================================================ +FINAL RESULTS +============================================================ +Total Connections: 142 +Total Events: 2841 +Average TTFE: 43 ms +============================================================ +``` + +### How to Read the Results + +**Live SSE Metrics Box (Updates every 10 seconds):** + +```text +SSE Metrics | Active: 8 | Total Conn: 142 | Events: 2841 +Rates: 2.4 conn/s | 47.3 events/s | TTFE: 43ms +``` + +- **Active**: Current number of open SSE connections +- **Total Conn**: Cumulative connections established +- **Events**: Total SSE events received +- **conn/s**: Connection establishment rate +- **events/s**: Event delivery rate +- **TTFE**: Average time to first event + +**Standard Locust Table:** + +```text +Type Name # reqs # fails | Avg Min Max Med | req/s +POST /v1/workflows/run 142 0(0.00%) | 41 18 192 38 | 2.37 +``` + +- **Type**: Always POST for our SSE requests +- **Name**: The API endpoint being tested +- **# reqs**: Total requests made +- **# fails**: Failed requests (should be 0) +- **Avg/Min/Max/Med**: Response time percentiles (ms) +- **req/s**: Request throughput + +**Performance Targets:** + +✅ **Good Performance**: + +- Zero failures (0.00%) +- TTFE < 100ms +- Stable active connections +- Consistent event throughput + +⚠️ **Warning Signs**: + +- Failures > 1% +- TTFE > 500ms +- Dropping active connections +- Declining event rate over time + +## Test Scenarios + +### Light Load + +```yaml +concurrency: 10 +iterations: 100 +``` + +### Normal Load + +```yaml +concurrency: 100 +iterations: 1000 +``` + +### Heavy Load + +```yaml +concurrency: 500 +iterations: 5000 +``` + +### Stress Test + +```yaml +concurrency: 1000 +iterations: 10000 +``` + +## Performance Tuning + +### API Server Optimization + +**Gunicorn Tuning for Different Load Levels**: + +```bash +# Light load (10-50 concurrent users) +uv run gunicorn --bind 0.0.0.0:5001 --workers 2 --worker-class gevent app:app + +# Medium load (50-200 concurrent users) +uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent --worker-connections 1000 app:app + +# Heavy load (200-1000 concurrent users) +uv run gunicorn --bind 0.0.0.0:5001 --workers 8 --worker-class gevent --worker-connections 2000 --max-requests 1000 app:app +``` + +**Worker calculation formula**: + +- Workers = (2 × CPU cores) + 1 +- For SSE/WebSocket: Use gevent worker class +- For CPU-bound tasks: Use sync workers + +### Database Optimization + +**PostgreSQL Connection Pool Tuning**: + +For high-concurrency stress testing, increase the PostgreSQL max connections in `docker/middleware.env`: + +```bash +# Edit docker/middleware.env +POSTGRES_MAX_CONNECTIONS=200 # Default is 100 + +# Recommended values for different load levels: +# Light load (10-50 users): 100 (default) +# Medium load (50-200 users): 200 +# Heavy load (200-1000 users): 500 +``` + +After changing, restart the PostgreSQL container: + +```bash +docker compose -f docker/docker-compose.middleware.yaml down db +docker compose -f docker/docker-compose.middleware.yaml up -d db +``` + +**Note**: Each connection uses ~10MB of RAM. Ensure your database server has sufficient memory: + +- 100 connections: ~1GB RAM +- 200 connections: ~2GB RAM +- 500 connections: ~5GB RAM + +### System Optimizations + +1. **Increase file descriptor limits**: + + ```bash + ulimit -n 65536 + ``` + +1. **TCP tuning for high concurrency** (Linux): + + ```bash + # Increase TCP buffer sizes + sudo sysctl -w net.core.rmem_max=134217728 + sudo sysctl -w net.core.wmem_max=134217728 + + # Enable TCP fast open + sudo sysctl -w net.ipv4.tcp_fastopen=3 + ``` + +1. **macOS specific**: + + ```bash + # Increase maximum connections + sudo sysctl -w kern.ipc.somaxconn=2048 + ``` + +## Troubleshooting + +### Common Issues + +1. **"ModuleNotFoundError: No module named 'locust'"**: + + ```bash + # Dependencies are installed automatically, but if needed: + uv --project api add --dev locust sseclient-py + ``` + +1. **"API key configuration not found"**: + + ```bash + # Run setup + python scripts/stress-test/setup_all.py + ``` + +1. **Services not running**: + + ```bash + # Start Dify API with Gunicorn (production mode) + cd api + uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app + + # Start Mock OpenAI server + python scripts/stress-test/setup/mock_openai_server.py + ``` + +1. **High error rate**: + + - Reduce concurrency level + - Check system resources (CPU, memory) + - Review API server logs for errors + - Increase timeout values if needed + +1. **Permission denied running script**: + + ```bash + chmod +x run_benchmark.sh + ``` + +## Advanced Usage + +### Running Multiple Iterations + +```bash +# Run stress test 3 times with 60-second intervals +for i in {1..3}; do + echo "Run $i of 3" + ./run_locust_stress_test.sh + sleep 60 +done +``` + +### Custom Locust Options + +Run Locust directly with custom options: + +```bash +# With specific user count and spawn rate +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \ + --host http://localhost:5001 --users 50 --spawn-rate 5 + +# Generate CSV reports +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \ + --host http://localhost:5001 --csv reports/results + +# Run for specific duration +uv run --project api python -m locust -f scripts/stress-test/sse_benchmark.py \ + --host http://localhost:5001 --run-time 5m --headless +``` + +### Comparing Results + +```bash +# Compare multiple stress test runs +ls -la reports/stress_test_*.txt | tail -5 +``` + +## Interpreting Performance Issues + +### High Response Times + +Possible causes: + +- Database query performance +- External API latency +- Insufficient server resources +- Network congestion + +### Low Throughput (RPS < 10) + +Check for: + +- CPU bottlenecks +- Memory constraints +- Database connection pooling +- API rate limiting + +### High Error Rate + +Investigate: + +- Server error logs +- Resource exhaustion +- Timeout configurations +- Connection limits + +## Why Locust? + +Locust was chosen over Drill for this stress test because: + +1. **Proper SSE Support**: Correctly handles streaming responses without premature closure +1. **Custom Metrics**: Can track SSE-specific metrics like TTFE and stream duration +1. **Web UI**: Real-time monitoring and control via web interface +1. **Python Integration**: Seamlessly integrates with existing Python setup code +1. **Extensibility**: Easy to customize for specific testing scenarios + +## Contributing + +To improve the stress test suite: + +1. Edit `stress_test.yml` for configuration changes +1. Modify `run_locust_stress_test.sh` for workflow improvements +1. Update question sets for better coverage +1. Add new metrics or analysis features diff --git a/scripts/stress-test/cleanup.py b/scripts/stress-test/cleanup.py new file mode 100755 index 0000000000..05b97be7ca --- /dev/null +++ b/scripts/stress-test/cleanup.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +import shutil +import sys +from pathlib import Path + +from common import Logger + + +def cleanup() -> None: + """Clean up all configuration files and reports created during setup and stress testing.""" + + log = Logger("Cleanup") + log.header("Stress Test Cleanup") + + config_dir = Path(__file__).parent / "setup" / "config" + reports_dir = Path(__file__).parent / "reports" + + dirs_to_clean = [] + if config_dir.exists(): + dirs_to_clean.append(config_dir) + if reports_dir.exists(): + dirs_to_clean.append(reports_dir) + + if not dirs_to_clean: + log.success("No directories to clean. Everything is already clean.") + return + + log.info("Cleaning up stress test data...") + log.info("This will remove:") + for dir_path in dirs_to_clean: + log.list_item(str(dir_path)) + + # List files that will be deleted + log.separator() + if config_dir.exists(): + config_files = list(config_dir.glob("*.json")) + if config_files: + log.info("Config files to be removed:") + for file in config_files: + log.list_item(file.name) + + if reports_dir.exists(): + report_files = list(reports_dir.glob("*")) + if report_files: + log.info("Report files to be removed:") + for file in report_files: + log.list_item(file.name) + + # Ask for confirmation if running interactively + if sys.stdin.isatty(): + log.separator() + log.warning("This action cannot be undone!") + confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ") + + if confirmation.lower() not in ["yes", "y"]: + log.error("Cleanup cancelled.") + return + + try: + # Remove directories and all their contents + for dir_path in dirs_to_clean: + shutil.rmtree(dir_path) + log.success(f"{dir_path.name} directory removed successfully!") + + log.separator() + log.info("To run the setup again, execute:") + log.list_item("python setup_all.py") + log.info("Or run scripts individually in this order:") + log.list_item("python setup/mock_openai_server.py (in a separate terminal)") + log.list_item("python setup/setup_admin.py") + log.list_item("python setup/login_admin.py") + log.list_item("python setup/install_openai_plugin.py") + log.list_item("python setup/configure_openai_plugin.py") + log.list_item("python setup/import_workflow_app.py") + log.list_item("python setup/create_api_key.py") + log.list_item("python setup/publish_workflow.py") + log.list_item("python setup/run_workflow.py") + + except PermissionError as e: + log.error(f"Permission denied: {e}") + log.info("Try running with appropriate permissions.") + except Exception as e: + log.error(f"An error occurred during cleanup: {e}") + + +if __name__ == "__main__": + cleanup() diff --git a/scripts/stress-test/common/__init__.py b/scripts/stress-test/common/__init__.py new file mode 100644 index 0000000000..a38d972ffb --- /dev/null +++ b/scripts/stress-test/common/__init__.py @@ -0,0 +1,6 @@ +"""Common utilities for Dify benchmark suite.""" + +from .config_helper import config_helper +from .logger_helper import Logger, ProgressLogger + +__all__ = ["Logger", "ProgressLogger", "config_helper"] diff --git a/scripts/stress-test/common/config_helper.py b/scripts/stress-test/common/config_helper.py new file mode 100644 index 0000000000..75fcbffa6f --- /dev/null +++ b/scripts/stress-test/common/config_helper.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 + +import json +from pathlib import Path +from typing import Any + + +class ConfigHelper: + """Helper class for reading and writing configuration files.""" + + def __init__(self, base_dir: Path | None = None): + """Initialize ConfigHelper with base directory. + + Args: + base_dir: Base directory for config files. If None, uses setup/config + """ + if base_dir is None: + # Default to config directory in setup folder + base_dir = Path(__file__).parent.parent / "setup" / "config" + self.base_dir = base_dir + self.state_file = "stress_test_state.json" + + def ensure_config_dir(self) -> None: + """Ensure the config directory exists.""" + self.base_dir.mkdir(exist_ok=True, parents=True) + + def get_config_path(self, filename: str) -> Path: + """Get the full path for a config file. + + Args: + filename: Name of the config file (e.g., 'admin_config.json') + + Returns: + Full path to the config file + """ + if not filename.endswith(".json"): + filename += ".json" + return self.base_dir / filename + + def read_config(self, filename: str) -> dict[str, Any] | None: + """Read a configuration file. + + DEPRECATED: Use read_state() or get_state_section() for new code. + This method provides backward compatibility. + + Args: + filename: Name of the config file to read + + Returns: + Dictionary containing config data, or None if file doesn't exist + """ + # Provide backward compatibility for old config names + if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: + section_map = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + return self.get_state_section(section_map[filename]) + + config_path = self.get_config_path(filename) + + if not config_path.exists(): + return None + + try: + with open(config_path) as f: + return json.load(f) + except (OSError, json.JSONDecodeError) as e: + print(f"❌ Error reading {filename}: {e}") + return None + + def write_config(self, filename: str, data: dict[str, Any]) -> bool: + """Write data to a configuration file. + + DEPRECATED: Use write_state() or update_state_section() for new code. + This method provides backward compatibility. + + Args: + filename: Name of the config file to write + data: Dictionary containing data to save + + Returns: + True if successful, False otherwise + """ + # Provide backward compatibility for old config names + if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: + section_map = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + return self.update_state_section(section_map[filename], data) + + self.ensure_config_dir() + config_path = self.get_config_path(filename) + + try: + with open(config_path, "w") as f: + json.dump(data, f, indent=2) + return True + except OSError as e: + print(f"❌ Error writing {filename}: {e}") + return False + + def config_exists(self, filename: str) -> bool: + """Check if a config file exists. + + Args: + filename: Name of the config file to check + + Returns: + True if file exists, False otherwise + """ + return self.get_config_path(filename).exists() + + def delete_config(self, filename: str) -> bool: + """Delete a configuration file. + + Args: + filename: Name of the config file to delete + + Returns: + True if successful, False otherwise + """ + config_path = self.get_config_path(filename) + + if not config_path.exists(): + return True # Already doesn't exist + + try: + config_path.unlink() + return True + except OSError as e: + print(f"❌ Error deleting {filename}: {e}") + return False + + def read_state(self) -> dict[str, Any] | None: + """Read the entire stress test state. + + Returns: + Dictionary containing all state data, or None if file doesn't exist + """ + state_path = self.get_config_path(self.state_file) + if not state_path.exists(): + return None + + try: + with open(state_path) as f: + return json.load(f) + except (OSError, json.JSONDecodeError) as e: + print(f"❌ Error reading {self.state_file}: {e}") + return None + + def write_state(self, data: dict[str, Any]) -> bool: + """Write the entire stress test state. + + Args: + data: Dictionary containing all state data to save + + Returns: + True if successful, False otherwise + """ + self.ensure_config_dir() + state_path = self.get_config_path(self.state_file) + + try: + with open(state_path, "w") as f: + json.dump(data, f, indent=2) + return True + except OSError as e: + print(f"❌ Error writing {self.state_file}: {e}") + return False + + def update_state_section(self, section: str, data: dict[str, Any]) -> bool: + """Update a specific section of the stress test state. + + Args: + section: Name of the section to update (e.g., 'admin', 'auth', 'app', 'api_key') + data: Dictionary containing section data to save + + Returns: + True if successful, False otherwise + """ + state = self.read_state() or {} + state[section] = data + return self.write_state(state) + + def get_state_section(self, section: str) -> dict[str, Any] | None: + """Get a specific section from the stress test state. + + Args: + section: Name of the section to get (e.g., 'admin', 'auth', 'app', 'api_key') + + Returns: + Dictionary containing section data, or None if not found + """ + state = self.read_state() + if state: + return state.get(section) + return None + + def get_token(self) -> str | None: + """Get the access token from auth section. + + Returns: + Access token string or None if not found + """ + auth = self.get_state_section("auth") + if auth: + return auth.get("access_token") + return None + + def get_app_id(self) -> str | None: + """Get the app ID from app section. + + Returns: + App ID string or None if not found + """ + app = self.get_state_section("app") + if app: + return app.get("app_id") + return None + + def get_api_key(self) -> str | None: + """Get the API key token from api_key section. + + Returns: + API key token string or None if not found + """ + api_key = self.get_state_section("api_key") + if api_key: + return api_key.get("token") + return None + + +# Create a default instance for convenience +config_helper = ConfigHelper() diff --git a/scripts/stress-test/common/logger_helper.py b/scripts/stress-test/common/logger_helper.py new file mode 100644 index 0000000000..c522685f1d --- /dev/null +++ b/scripts/stress-test/common/logger_helper.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 + +import sys +import time +from enum import Enum + + +class LogLevel(Enum): + """Log levels with associated colors and symbols.""" + + DEBUG = ("🔍", "\033[90m") # Gray + INFO = ("ℹ️ ", "\033[94m") # Blue + SUCCESS = ("✅", "\033[92m") # Green + WARNING = ("⚠️ ", "\033[93m") # Yellow + ERROR = ("❌", "\033[91m") # Red + STEP = ("🚀", "\033[96m") # Cyan + PROGRESS = ("📋", "\033[95m") # Magenta + + +class Logger: + """Logger class for formatted console output.""" + + def __init__(self, name: str | None = None, use_colors: bool = True): + """Initialize logger. + + Args: + name: Optional name for the logger (e.g., script name) + use_colors: Whether to use ANSI color codes + """ + self.name = name + self.use_colors = use_colors and sys.stdout.isatty() + self._reset_color = "\033[0m" if self.use_colors else "" + + def _format_message(self, level: LogLevel, message: str, indent: int = 0) -> str: + """Format a log message with level, color, and indentation. + + Args: + level: Log level + message: Message to log + indent: Number of spaces to indent + + Returns: + Formatted message string + """ + symbol, color = level.value + color = color if self.use_colors else "" + reset = self._reset_color + + prefix = " " * indent + + if self.name and level in [LogLevel.STEP, LogLevel.ERROR]: + return f"{prefix}{color}{symbol} [{self.name}] {message}{reset}" + else: + return f"{prefix}{color}{symbol} {message}{reset}" + + def debug(self, message: str, indent: int = 0) -> None: + """Log debug message.""" + print(self._format_message(LogLevel.DEBUG, message, indent)) + + def info(self, message: str, indent: int = 0) -> None: + """Log info message.""" + print(self._format_message(LogLevel.INFO, message, indent)) + + def success(self, message: str, indent: int = 0) -> None: + """Log success message.""" + print(self._format_message(LogLevel.SUCCESS, message, indent)) + + def warning(self, message: str, indent: int = 0) -> None: + """Log warning message.""" + print(self._format_message(LogLevel.WARNING, message, indent)) + + def error(self, message: str, indent: int = 0) -> None: + """Log error message.""" + print(self._format_message(LogLevel.ERROR, message, indent), file=sys.stderr) + + def step(self, message: str, indent: int = 0) -> None: + """Log a step in a process.""" + print(self._format_message(LogLevel.STEP, message, indent)) + + def progress(self, message: str, indent: int = 0) -> None: + """Log progress information.""" + print(self._format_message(LogLevel.PROGRESS, message, indent)) + + def separator(self, char: str = "-", length: int = 60) -> None: + """Print a separator line.""" + print(char * length) + + def header(self, title: str, width: int = 60) -> None: + """Print a formatted header.""" + if self.use_colors: + print(f"\n\033[1m{'=' * width}\033[0m") # Bold + print(f"\033[1m{title.center(width)}\033[0m") + print(f"\033[1m{'=' * width}\033[0m\n") + else: + print(f"\n{'=' * width}") + print(title.center(width)) + print(f"{'=' * width}\n") + + def box(self, title: str, width: int = 60) -> None: + """Print a title in a box.""" + border = "═" * (width - 2) + if self.use_colors: + print(f"\033[1m╔{border}╗\033[0m") + print(f"\033[1m║{title.center(width - 2)}║\033[0m") + print(f"\033[1m╚{border}╝\033[0m") + else: + print(f"╔{border}╗") + print(f"║{title.center(width - 2)}║") + print(f"╚{border}╝") + + def list_item(self, item: str, indent: int = 2) -> None: + """Print a list item.""" + prefix = " " * indent + print(f"{prefix}• {item}") + + def key_value(self, key: str, value: str, indent: int = 2) -> None: + """Print a key-value pair.""" + prefix = " " * indent + if self.use_colors: + print(f"{prefix}\033[1m{key}:\033[0m {value}") + else: + print(f"{prefix}{key}: {value}") + + def spinner_start(self, message: str) -> None: + """Start a spinner (simple implementation).""" + sys.stdout.write(f"\r{message}... ") + sys.stdout.flush() + + def spinner_stop(self, success: bool = True, message: str | None = None) -> None: + """Stop the spinner and show result.""" + if success: + symbol = "✅" if message else "Done" + sys.stdout.write(f"\r{symbol} {message or ''}\n") + else: + symbol = "❌" if message else "Failed" + sys.stdout.write(f"\r{symbol} {message or ''}\n") + sys.stdout.flush() + + +class ProgressLogger: + """Logger for tracking progress through multiple steps.""" + + def __init__(self, total_steps: int, logger: Logger | None = None): + """Initialize progress logger. + + Args: + total_steps: Total number of steps + logger: Logger instance to use (creates new if None) + """ + self.total_steps = total_steps + self.current_step = 0 + self.logger = logger or Logger() + self.start_time = time.time() + + def next_step(self, description: str) -> None: + """Move to next step and log it.""" + self.current_step += 1 + elapsed = time.time() - self.start_time + + if self.logger.use_colors: + progress_bar = self._create_progress_bar() + print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}") + self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") + else: + print(f"\n[Step {self.current_step}/{self.total_steps}]") + self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") + + def _create_progress_bar(self, width: int = 20) -> str: + """Create a simple progress bar.""" + filled = int(width * self.current_step / self.total_steps) + bar = "█" * filled + "░" * (width - filled) + percentage = int(100 * self.current_step / self.total_steps) + return f"[{bar}] {percentage}%" + + def complete(self) -> None: + """Mark progress as complete.""" + elapsed = time.time() - self.start_time + self.logger.success(f"All steps completed! Total time: {elapsed:.1f}s") + + +# Create default logger instance +logger = Logger() + + +# Convenience functions using default logger +def debug(message: str, indent: int = 0) -> None: + """Log debug message using default logger.""" + logger.debug(message, indent) + + +def info(message: str, indent: int = 0) -> None: + """Log info message using default logger.""" + logger.info(message, indent) + + +def success(message: str, indent: int = 0) -> None: + """Log success message using default logger.""" + logger.success(message, indent) + + +def warning(message: str, indent: int = 0) -> None: + """Log warning message using default logger.""" + logger.warning(message, indent) + + +def error(message: str, indent: int = 0) -> None: + """Log error message using default logger.""" + logger.error(message, indent) + + +def step(message: str, indent: int = 0) -> None: + """Log step using default logger.""" + logger.step(message, indent) + + +def progress(message: str, indent: int = 0) -> None: + """Log progress using default logger.""" + logger.progress(message, indent) diff --git a/scripts/stress-test/locust.conf b/scripts/stress-test/locust.conf new file mode 100644 index 0000000000..87bd8c2870 --- /dev/null +++ b/scripts/stress-test/locust.conf @@ -0,0 +1,37 @@ +# Locust configuration file for Dify SSE benchmark + +# Target host +host = http://localhost:5001 + +# Number of users to simulate +users = 10 + +# Spawn rate (users per second) +spawn-rate = 2 + +# Run time (use format like 30s, 5m, 1h) +run-time = 1m + +# Locustfile to use +locustfile = scripts/stress-test/sse_benchmark.py + +# Headless mode (no web UI) +headless = true + +# Print stats in the console +print-stats = true + +# Only print summary stats +only-summary = false + +# Reset statistics after ramp-up +reset-stats = false + +# Log level +loglevel = INFO + +# CSV output (uncomment to enable) +# csv = reports/locust_results + +# HTML report (uncomment to enable) +# html = reports/locust_report.html \ No newline at end of file diff --git a/scripts/stress-test/run_locust_stress_test.sh b/scripts/stress-test/run_locust_stress_test.sh new file mode 100755 index 0000000000..665cb68754 --- /dev/null +++ b/scripts/stress-test/run_locust_stress_test.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# Run Dify SSE Stress Test using Locust + +set -e + +# Get the directory where this script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Go to project root first, then to script dir +PROJECT_ROOT="$( cd "${SCRIPT_DIR}/../.." && pwd )" +cd "${PROJECT_ROOT}" +STRESS_TEST_DIR="scripts/stress-test" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +# Configuration +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +REPORT_DIR="${STRESS_TEST_DIR}/reports" +CSV_PREFIX="${REPORT_DIR}/locust_${TIMESTAMP}" +HTML_REPORT="${REPORT_DIR}/locust_report_${TIMESTAMP}.html" +SUMMARY_REPORT="${REPORT_DIR}/locust_summary_${TIMESTAMP}.txt" + +# Create reports directory if it doesn't exist +mkdir -p "${REPORT_DIR}" + +echo -e "${BLUE}╔════════════════════════════════════════════════════════════════╗${NC}" +echo -e "${BLUE}║ DIFY SSE WORKFLOW STRESS TEST (LOCUST) ║${NC}" +echo -e "${BLUE}╚════════════════════════════════════════════════════════════════╝${NC}" +echo + +# Check if services are running +echo -e "${YELLOW}Checking services...${NC}" + +# Check Dify API +if curl -s -f http://localhost:5001/health > /dev/null 2>&1; then + echo -e "${GREEN}✓ Dify API is running${NC}" + + # Warn if running in debug mode (check for werkzeug in process) + if ps aux | grep -v grep | grep -q "werkzeug.*5001\|flask.*run.*5001"; then + echo -e "${YELLOW}⚠ WARNING: API appears to be running in debug mode (Flask development server)${NC}" + echo -e "${YELLOW} This will give inaccurate benchmark results!${NC}" + echo -e "${YELLOW} For accurate benchmarking, restart with Gunicorn:${NC}" + echo -e "${CYAN} cd api && uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app${NC}" + echo + echo -n "Continue anyway? (not recommended) [y/N]: " + read -t 10 continue_debug || continue_debug="n" + if [ "$continue_debug" != "y" ] && [ "$continue_debug" != "Y" ]; then + echo -e "${RED}Benchmark cancelled. Please restart API with Gunicorn.${NC}" + exit 1 + fi + fi +else + echo -e "${RED}✗ Dify API is not running on port 5001${NC}" + echo -e "${YELLOW} Start it with Gunicorn for accurate benchmarking:${NC}" + echo -e "${CYAN} cd api && uv run gunicorn --bind 0.0.0.0:5001 --workers 4 --worker-class gevent app:app${NC}" + exit 1 +fi + +# Check Mock OpenAI server +if curl -s -f http://localhost:5004/v1/models > /dev/null 2>&1; then + echo -e "${GREEN}✓ Mock OpenAI server is running${NC}" +else + echo -e "${RED}✗ Mock OpenAI server is not running on port 5004${NC}" + echo -e "${YELLOW} Start it with: python scripts/stress-test/setup/mock_openai_server.py${NC}" + exit 1 +fi + +# Check API token exists +if [ ! -f "${STRESS_TEST_DIR}/setup/config/stress_test_state.json" ]; then + echo -e "${RED}✗ Stress test configuration not found${NC}" + echo -e "${YELLOW} Run setup first: python scripts/stress-test/setup_all.py${NC}" + exit 1 +fi + +API_TOKEN=$(python3 -c "import json; state = json.load(open('${STRESS_TEST_DIR}/setup/config/stress_test_state.json')); print(state.get('api_key', {}).get('token', ''))" 2>/dev/null) +if [ -z "$API_TOKEN" ]; then + echo -e "${RED}✗ Failed to read API token from stress test state${NC}" + exit 1 +fi +echo -e "${GREEN}✓ API token found: ${API_TOKEN:0:10}...${NC}" + +echo +echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" +echo -e "${CYAN} STRESS TEST PARAMETERS ${NC}" +echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + +# Parse configuration +USERS=$(grep "^users" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') +SPAWN_RATE=$(grep "^spawn-rate" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') +RUN_TIME=$(grep "^run-time" ${STRESS_TEST_DIR}/locust.conf | cut -d'=' -f2 | tr -d ' ') + +echo -e " ${YELLOW}Users:${NC} $USERS concurrent users" +echo -e " ${YELLOW}Spawn Rate:${NC} $SPAWN_RATE users/second" +echo -e " ${YELLOW}Duration:${NC} $RUN_TIME" +echo -e " ${YELLOW}Mode:${NC} SSE Streaming" +echo + +# Ask user for run mode +echo -e "${YELLOW}Select run mode:${NC}" +echo " 1) Headless (CLI only) - Default" +echo " 2) Web UI (http://localhost:8089)" +echo -n "Choice [1]: " +read -t 10 choice || choice="1" +echo + +# Use SSE stress test script +LOCUST_SCRIPT="${STRESS_TEST_DIR}/sse_benchmark.py" + +# Prepare Locust command +if [ "$choice" = "2" ]; then + echo -e "${BLUE}Starting Locust with Web UI...${NC}" + echo -e "${YELLOW}Access the web interface at: ${CYAN}http://localhost:8089${NC}" + echo + + # Run with web UI + uv --project api run locust \ + -f ${LOCUST_SCRIPT} \ + --host http://localhost:5001 \ + --web-port 8089 +else + echo -e "${BLUE}Starting stress test in headless mode...${NC}" + echo + + # Run in headless mode with CSV output + uv --project api run locust \ + -f ${LOCUST_SCRIPT} \ + --host http://localhost:5001 \ + --users $USERS \ + --spawn-rate $SPAWN_RATE \ + --run-time $RUN_TIME \ + --headless \ + --print-stats \ + --csv=$CSV_PREFIX \ + --html=$HTML_REPORT \ + 2>&1 | tee $SUMMARY_REPORT + + echo + echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${GREEN} STRESS TEST COMPLETE ${NC}" + echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" + echo + echo -e "${BLUE}Reports generated:${NC}" + echo -e " ${YELLOW}Summary:${NC} $SUMMARY_REPORT" + echo -e " ${YELLOW}HTML Report:${NC} $HTML_REPORT" + echo -e " ${YELLOW}CSV Stats:${NC} ${CSV_PREFIX}_stats.csv" + echo -e " ${YELLOW}CSV History:${NC} ${CSV_PREFIX}_stats_history.csv" + echo + echo -e "${CYAN}View HTML report:${NC}" + echo " open $HTML_REPORT # macOS" + echo " xdg-open $HTML_REPORT # Linux" + echo + + # Parse and display key metrics + echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + echo -e "${CYAN} KEY METRICS ${NC}" + echo -e "${CYAN}═══════════════════════════════════════════════════════════════${NC}" + + if [ -f "${CSV_PREFIX}_stats.csv" ]; then + python3 - < None: + """Configure OpenAI plugin with mock server credentials.""" + + log = Logger("ConfigPlugin") + log.header("Configuring OpenAI Plugin") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + log.step("Configuring OpenAI plugin with mock server...") + + # API endpoint for plugin configuration + base_url = "http://localhost:5001" + config_endpoint = f"{base_url}/console/api/workspaces/current/model-providers/langgenius/openai/openai/credentials" + + # Configuration payload with mock server + config_payload = { + "credentials": { + "openai_api_key": "apikey", + "openai_organization": None, + "openai_api_base": "http://host.docker.internal:5004", + } + } + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the configuration request + with httpx.Client() as client: + response = client.post( + config_endpoint, + json=config_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + log.success("OpenAI plugin configured successfully!") + log.key_value("API Base", config_payload["credentials"]["openai_api_base"]) + log.key_value("API Key", config_payload["credentials"]["openai_api_key"]) + + elif response.status_code == 201: + log.success("OpenAI plugin credentials created successfully!") + log.key_value("API Base", config_payload["credentials"]["openai_api_base"]) + log.key_value("API Key", config_payload["credentials"]["openai_api_key"]) + + elif response.status_code == 401: + log.error("Configuration failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error(f"Configuration failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + configure_openai_plugin() diff --git a/scripts/stress-test/setup/create_api_key.py b/scripts/stress-test/setup/create_api_key.py new file mode 100755 index 0000000000..cd04fe57eb --- /dev/null +++ b/scripts/stress-test/setup/create_api_key.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def create_api_key() -> None: + """Create API key for the imported app.""" + + log = Logger("CreateAPIKey") + log.header("Creating API Key") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + return + + # Read app_id from config + app_id = config_helper.get_app_id() + if not app_id: + log.error("No app_id found in config") + log.info("Please run import_workflow_app.py first to import the app") + return + + log.step(f"Creating API key for app: {app_id}") + + # API endpoint for creating API key + base_url = "http://localhost:5001" + api_key_endpoint = f"{base_url}/console/api/apps/{app_id}/api-keys" + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Length": "0", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the API key creation request + with httpx.Client() as client: + response = client.post( + api_key_endpoint, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200 or response.status_code == 201: + response_data = response.json() + + api_key_id = response_data.get("id") + api_key_token = response_data.get("token") + + if api_key_token: + log.success("API key created successfully!") + log.key_value("Key ID", api_key_id) + log.key_value("Token", api_key_token) + log.key_value("Type", response_data.get("type")) + + # Save API key to config + api_key_config = { + "id": api_key_id, + "token": api_key_token, + "type": response_data.get("type"), + "app_id": app_id, + "created_at": response_data.get("created_at"), + } + + if config_helper.write_config("api_key_config", api_key_config): + log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}") + else: + log.error("No API token received") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response.status_code == 401: + log.error("API key creation failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error(f"API key creation failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + create_api_key() diff --git a/scripts/stress-test/setup/dsl/workflow_llm.yml b/scripts/stress-test/setup/dsl/workflow_llm.yml new file mode 100644 index 0000000000..c0fd2c7d8b --- /dev/null +++ b/scripts/stress-test/setup/dsl/workflow_llm.yml @@ -0,0 +1,176 @@ +app: + description: '' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: workflow_llm + use_icon_as_answer_icon: false +dependencies: +- current_identifier: null + type: marketplace + value: + marketplace_plugin_unique_identifier: langgenius/openai:0.2.5@373362a028986aae53a7baf73a7f11991ba3c22c69eaf97d6cde048cfd4a9f98 +kind: app +version: 0.4.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: llm + id: 1757611990947-source-1757611992921-target + source: '1757611990947' + sourceHandle: source + target: '1757611992921' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: llm + targetType: end + id: 1757611992921-source-1757611996447-target + source: '1757611992921' + sourceHandle: source + target: '1757611996447' + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + desc: '' + selected: false + title: Start + type: start + variables: + - label: question + max_length: null + options: [] + required: true + type: text-input + variable: question + height: 90 + id: '1757611990947' + position: + x: 30 + y: 245 + positionAbsolute: + x: 30 + y: 245 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + context: + enabled: false + variable_selector: [] + desc: '' + model: + completion_params: + temperature: 0.7 + mode: chat + name: gpt-4o + provider: langgenius/openai/openai + prompt_template: + - id: c165fcb6-f1f0-42f2-abab-e81982434deb + role: system + text: '' + - role: user + text: '{{#1757611990947.question#}}' + selected: false + title: LLM + type: llm + variables: [] + vision: + enabled: false + height: 90 + id: '1757611992921' + position: + x: 334 + y: 245 + positionAbsolute: + x: 334 + y: 245 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + desc: '' + outputs: + - value_selector: + - '1757611992921' + - text + value_type: string + variable: answer + selected: false + title: End + type: end + height: 90 + id: '1757611996447' + position: + x: 638 + y: 245 + positionAbsolute: + x: 638 + y: 245 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 244 + viewport: + x: 0 + y: 0 + zoom: 0.7 diff --git a/scripts/stress-test/setup/import_workflow_app.py b/scripts/stress-test/setup/import_workflow_app.py new file mode 100755 index 0000000000..86d0239e35 --- /dev/null +++ b/scripts/stress-test/setup/import_workflow_app.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def import_workflow_app() -> None: + """Import workflow app from DSL file and save app_id.""" + + log = Logger("ImportApp") + log.header("Importing Workflow Application") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + # Read workflow DSL file + dsl_path = Path(__file__).parent / "dsl" / "workflow_llm.yml" + + if not dsl_path.exists(): + log.error(f"DSL file not found: {dsl_path}") + return + + with open(dsl_path) as f: + yaml_content = f.read() + + log.step("Importing workflow app from DSL...") + log.key_value("DSL file", dsl_path.name) + + # API endpoint for app import + base_url = "http://localhost:5001" + import_endpoint = f"{base_url}/console/api/apps/imports" + + # Import payload + import_payload = {"mode": "yaml-content", "yaml_content": yaml_content} + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the import request + with httpx.Client() as client: + response = client.post( + import_endpoint, + json=import_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + response_data = response.json() + + # Check import status + if response_data.get("status") == "completed": + app_id = response_data.get("app_id") + + if app_id: + log.success("Workflow app imported successfully!") + log.key_value("App ID", app_id) + log.key_value("App Mode", response_data.get("app_mode")) + log.key_value("DSL Version", response_data.get("imported_dsl_version")) + + # Save app_id to config + app_config = { + "app_id": app_id, + "app_mode": response_data.get("app_mode"), + "app_name": "workflow_llm", + "dsl_version": response_data.get("imported_dsl_version"), + } + + if config_helper.write_config("app_config", app_config): + log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}") + else: + log.error("Import completed but no app_id received") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response_data.get("status") == "failed": + log.error("Import failed") + log.error(f"Error: {response_data.get('error')}") + else: + log.warning(f"Import status: {response_data.get('status')}") + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + + elif response.status_code == 401: + log.error("Import failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + else: + log.error(f"Import failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + import_workflow_app() diff --git a/scripts/stress-test/setup/install_openai_plugin.py b/scripts/stress-test/setup/install_openai_plugin.py new file mode 100755 index 0000000000..055e5661f8 --- /dev/null +++ b/scripts/stress-test/setup/install_openai_plugin.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import time + +import httpx +from common import Logger, config_helper + + +def install_openai_plugin() -> None: + """Install OpenAI plugin using saved access token.""" + + log = Logger("InstallPlugin") + log.header("Installing OpenAI Plugin") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + log.info("Please run login_admin.py first to get access token") + return + + log.step("Installing OpenAI plugin...") + + # API endpoint for plugin installation + base_url = "http://localhost:5001" + install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace" + + # Plugin identifier + plugin_payload = { + "plugin_unique_identifiers": [ + "langgenius/openai:0.2.5@373362a028986aae53a7baf73a7f11991ba3c22c69eaf97d6cde048cfd4a9f98" + ] + } + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the installation request + with httpx.Client() as client: + response = client.post( + install_endpoint, + json=plugin_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200: + response_data = response.json() + task_id = response_data.get("task_id") + + if not task_id: + log.error("No task ID received from installation request") + return + + log.progress(f"Installation task created: {task_id}") + log.info("Polling for task completion...") + + # Poll for task completion + task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}" + + max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max + attempt = 0 + + log.spinner_start("Installing plugin") + + while attempt < max_attempts: + attempt += 1 + time.sleep(2) # Wait 2 seconds between polls + + task_response = client.get( + task_endpoint, + headers=headers, + cookies=cookies, + ) + + if task_response.status_code != 200: + log.spinner_stop( + success=False, + message=f"Failed to get task status: {task_response.status_code}", + ) + return + + task_data = task_response.json() + task_info = task_data.get("task", {}) + status = task_info.get("status") + + if status == "success": + log.spinner_stop(success=True, message="Plugin installed!") + log.success("OpenAI plugin installed successfully!") + + # Display plugin info + plugins = task_info.get("plugins", []) + if plugins: + plugin_info = plugins[0] + log.key_value("Plugin ID", plugin_info.get("plugin_id")) + log.key_value("Message", plugin_info.get("message")) + break + + elif status == "failed": + log.spinner_stop(success=False, message="Installation failed") + log.error("Plugin installation failed") + plugins = task_info.get("plugins", []) + if plugins: + for plugin in plugins: + log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}") + break + + # Continue polling if status is "pending" or other + + else: + log.spinner_stop(success=False, message="Installation timed out") + log.error("Installation timed out after 60 seconds") + + elif response.status_code == 401: + log.error("Installation failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + elif response.status_code == 409: + log.warning("Plugin may already be installed") + log.debug(f"Response: {response.text}") + else: + log.error(f"Installation failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + install_openai_plugin() diff --git a/scripts/stress-test/setup/login_admin.py b/scripts/stress-test/setup/login_admin.py new file mode 100755 index 0000000000..572b8fb650 --- /dev/null +++ b/scripts/stress-test/setup/login_admin.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def login_admin() -> None: + """Login with admin account and save access token.""" + + log = Logger("Login") + log.header("Admin Login") + + # Read admin credentials from config + admin_config = config_helper.read_config("admin_config") + + if not admin_config: + log.error("Admin config not found") + log.info("Please run setup_admin.py first to create the admin account") + return + + log.info(f"Logging in with email: {admin_config['email']}") + + # API login endpoint + base_url = "http://localhost:5001" + login_endpoint = f"{base_url}/console/api/login" + + # Prepare login payload + login_payload = { + "email": admin_config["email"], + "password": admin_config["password"], + "remember_me": True, + } + + try: + # Make the login request + with httpx.Client() as client: + response = client.post( + login_endpoint, + json=login_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 200: + log.success("Login successful!") + + # Extract token from response + response_data = response.json() + + # Check if login was successful + if response_data.get("result") != "success": + log.error(f"Login failed: {response_data}") + return + + # Extract tokens from data field + token_data = response_data.get("data", {}) + access_token = token_data.get("access_token", "") + refresh_token = token_data.get("refresh_token", "") + + if not access_token: + log.error("No access token found in response") + log.debug(f"Full response: {json.dumps(response_data, indent=2)}") + return + + # Save token to config file + token_config = { + "email": admin_config["email"], + "access_token": access_token, + "refresh_token": refresh_token, + } + + # Save token config + if config_helper.write_config("token_config", token_config): + log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}") + + # Show truncated token for verification + token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved" + log.key_value("Access token", token_display) + + elif response.status_code == 401: + log.error("Login failed: Invalid credentials") + log.debug(f"Response: {response.text}") + else: + log.error(f"Login failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + login_admin() diff --git a/scripts/stress-test/setup/mock_openai_server.py b/scripts/stress-test/setup/mock_openai_server.py new file mode 100755 index 0000000000..7333c66e57 --- /dev/null +++ b/scripts/stress-test/setup/mock_openai_server.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 + +import json +import time +import uuid +from collections.abc import Iterator +from typing import Any + +from flask import Flask, Response, jsonify, request + +app = Flask(__name__) + +# Mock models list +MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677649963, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { + "id": "text-embedding-ada-002", + "object": "model", + "created": 1671217299, + "owned_by": "openai-internal", + }, +] + + +@app.route("/v1/models", methods=["GET"]) +def list_models() -> Any: + """List available models.""" + return jsonify({"object": "list", "data": MODELS}) + + +@app.route("/v1/chat/completions", methods=["POST"]) +def chat_completions() -> Any: + """Handle chat completions.""" + data = request.json or {} + model = data.get("model", "gpt-3.5-turbo") + messages = data.get("messages", []) + stream = data.get("stream", False) + + # Generate mock response + response_content = "This is a mock response from the OpenAI server." + if messages: + last_message = messages[-1].get("content", "") + response_content = f"Mock response to: {last_message[:100]}..." + + if stream: + # Streaming response + def generate() -> Iterator[str]: + # Send initial chunk + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + + # Send content in chunks + words = response_content.split() + for word in words: + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": word + " "}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + time.sleep(0.05) # Simulate streaming delay + + # Send final chunk + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + return Response(generate(), mimetype="text/event-stream") + else: + # Non-streaming response + return jsonify( + { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": response_content}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": len(str(messages)), + "completion_tokens": len(response_content.split()), + "total_tokens": len(str(messages)) + len(response_content.split()), + }, + } + ) + + +@app.route("/v1/completions", methods=["POST"]) +def completions() -> Any: + """Handle text completions.""" + data = request.json or {} + model = data.get("model", "gpt-3.5-turbo-instruct") + prompt = data.get("prompt", "") + + response_text = f"Mock completion for prompt: {prompt[:100]}..." + + return jsonify( + { + "id": f"cmpl-{uuid.uuid4().hex[:8]}", + "object": "text_completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "text": response_text, + "index": 0, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": len(prompt.split()), + "completion_tokens": len(response_text.split()), + "total_tokens": len(prompt.split()) + len(response_text.split()), + }, + } + ) + + +@app.route("/v1/embeddings", methods=["POST"]) +def embeddings() -> Any: + """Handle embeddings requests.""" + data = request.json or {} + model = data.get("model", "text-embedding-ada-002") + input_text = data.get("input", "") + + # Generate mock embedding (1536 dimensions for ada-002) + mock_embedding = [0.1] * 1536 + + return jsonify( + { + "object": "list", + "data": [{"object": "embedding", "embedding": mock_embedding, "index": 0}], + "model": model, + "usage": { + "prompt_tokens": len(input_text.split()), + "total_tokens": len(input_text.split()), + }, + } + ) + + +@app.route("/v1/models/", methods=["GET"]) +def get_model(model_id: str) -> tuple[Any, int] | Any: + """Get specific model details.""" + for model in MODELS: + if model["id"] == model_id: + return jsonify(model) + + return jsonify({"error": "Model not found"}), 404 + + +@app.route("/health", methods=["GET"]) +def health() -> Any: + """Health check endpoint.""" + return jsonify({"status": "healthy"}) + + +if __name__ == "__main__": + print("🚀 Starting Mock OpenAI Server on http://localhost:5004") + print("Available endpoints:") + print(" - GET /v1/models") + print(" - POST /v1/chat/completions") + print(" - POST /v1/completions") + print(" - POST /v1/embeddings") + print(" - GET /v1/models/") + print(" - GET /health") + app.run(host="0.0.0.0", port=5004, debug=True) diff --git a/scripts/stress-test/setup/publish_workflow.py b/scripts/stress-test/setup/publish_workflow.py new file mode 100755 index 0000000000..b772eccebd --- /dev/null +++ b/scripts/stress-test/setup/publish_workflow.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def publish_workflow() -> None: + """Publish the imported workflow app.""" + + log = Logger("PublishWorkflow") + log.header("Publishing Workflow") + + # Read token from config + access_token = config_helper.get_token() + if not access_token: + log.error("No access token found in config") + return + + # Read app_id from config + app_id = config_helper.get_app_id() + if not app_id: + log.error("No app_id found in config") + return + + log.step(f"Publishing workflow for app: {app_id}") + + # API endpoint for publishing workflow + base_url = "http://localhost:5001" + publish_endpoint = f"{base_url}/console/api/apps/{app_id}/workflows/publish" + + # Publish payload + publish_payload = {"marked_name": "", "marked_comment": ""} + + headers = { + "Accept": "*/*", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "DNT": "1", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36", + "authorization": f"Bearer {access_token}", + "content-type": "application/json", + "sec-ch-ua": '"Not;A=Brand";v="99", "Google Chrome";v="139", "Chromium";v="139"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + } + + cookies = {"locale": "en-US"} + + try: + # Make the publish request + with httpx.Client() as client: + response = client.post( + publish_endpoint, + json=publish_payload, + headers=headers, + cookies=cookies, + ) + + if response.status_code == 200 or response.status_code == 201: + log.success("Workflow published successfully!") + log.key_value("App ID", app_id) + + # Try to parse response if it has JSON content + if response.text: + try: + response_data = response.json() + if response_data: + log.debug(f"Response: {json.dumps(response_data, indent=2)}") + except json.JSONDecodeError: + # Response might be empty or non-JSON + pass + + elif response.status_code == 401: + log.error("Workflow publish failed: Unauthorized") + log.info("Token may have expired. Please run login_admin.py again") + elif response.status_code == 404: + log.error("Workflow publish failed: App not found") + log.info("Make sure the app was imported successfully") + else: + log.error(f"Workflow publish failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + publish_workflow() diff --git a/scripts/stress-test/setup/run_workflow.py b/scripts/stress-test/setup/run_workflow.py new file mode 100755 index 0000000000..6da0ff17be --- /dev/null +++ b/scripts/stress-test/setup/run_workflow.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import json + +import httpx +from common import Logger, config_helper + + +def run_workflow(question: str = "fake question", streaming: bool = True) -> None: + """Run the workflow app with a question.""" + + log = Logger("RunWorkflow") + log.header("Running Workflow") + + # Read API key from config + api_token = config_helper.get_api_key() + if not api_token: + log.error("No API token found in config") + log.info("Please run create_api_key.py first to create an API key") + return + + log.key_value("Question", question) + log.key_value("Mode", "Streaming" if streaming else "Blocking") + log.separator() + + # API endpoint for running workflow + base_url = "http://localhost:5001" + run_endpoint = f"{base_url}/v1/workflows/run" + + # Run payload + run_payload = { + "inputs": {"question": question}, + "user": "default user", + "response_mode": "streaming" if streaming else "blocking", + } + + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + } + + try: + # Make the run request + with httpx.Client(timeout=30.0) as client: + if streaming: + # Handle streaming response + with client.stream( + "POST", + run_endpoint, + json=run_payload, + headers=headers, + ) as response: + if response.status_code == 200: + log.success("Workflow started successfully!") + log.separator() + log.step("Streaming response:") + + for line in response.iter_lines(): + if line.startswith("data: "): + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + log.success("Workflow completed!") + break + try: + data = json.loads(data_str) + event = data.get("event") + + if event == "workflow_started": + log.progress(f"Workflow started: {data.get('data', {}).get('id')}") + elif event == "node_started": + node_data = data.get("data", {}) + log.progress( + f"Node started: {node_data.get('node_type')} - {node_data.get('title')}" + ) + elif event == "node_finished": + node_data = data.get("data", {}) + log.progress( + f"Node finished: {node_data.get('node_type')} - {node_data.get('title')}" + ) + + # Print output if it's the LLM node + outputs = node_data.get("outputs", {}) + if outputs.get("text"): + log.separator() + log.info("💬 LLM Response:") + log.info(outputs.get("text"), indent=2) + log.separator() + + elif event == "workflow_finished": + workflow_data = data.get("data", {}) + outputs = workflow_data.get("outputs", {}) + if outputs.get("answer"): + log.separator() + log.info("📤 Final Answer:") + log.info(outputs.get("answer"), indent=2) + log.separator() + log.key_value( + "Total tokens", + str(workflow_data.get("total_tokens", 0)), + ) + log.key_value( + "Total steps", + str(workflow_data.get("total_steps", 0)), + ) + + elif event == "error": + log.error(f"Error: {data.get('message')}") + + except json.JSONDecodeError: + # Some lines might not be JSON + pass + else: + log.error(f"Workflow run failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + else: + # Handle blocking response + response = client.post( + run_endpoint, + json=run_payload, + headers=headers, + ) + + if response.status_code == 200: + log.success("Workflow completed successfully!") + response_data = response.json() + + log.separator() + log.debug(f"Full response: {json.dumps(response_data, indent=2)}") + + # Extract the answer if available + outputs = response_data.get("data", {}).get("outputs", {}) + if outputs.get("answer"): + log.separator() + log.info("📤 Final Answer:") + log.info(outputs.get("answer"), indent=2) + else: + log.error(f"Workflow run failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except httpx.TimeoutException: + log.error("Request timed out") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + # Allow passing question as command line argument + if len(sys.argv) > 1: + question = " ".join(sys.argv[1:]) + else: + question = "What is the capital of France?" + + run_workflow(question=question, streaming=True) diff --git a/scripts/stress-test/setup/setup_admin.py b/scripts/stress-test/setup/setup_admin.py new file mode 100755 index 0000000000..a5e9161210 --- /dev/null +++ b/scripts/stress-test/setup/setup_admin.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import httpx +from common import Logger, config_helper + + +def setup_admin_account() -> None: + """Setup Dify API with an admin account.""" + + log = Logger("SetupAdmin") + log.header("Setting up Admin Account") + + # Admin account credentials + admin_config = { + "email": "test@dify.ai", + "username": "dify", + "password": "password123", + } + + # Save credentials to config file + if config_helper.write_config("admin_config", admin_config): + log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}") + + # API setup endpoint + base_url = "http://localhost:5001" + setup_endpoint = f"{base_url}/console/api/setup" + + # Prepare setup payload + setup_payload = { + "email": admin_config["email"], + "name": admin_config["username"], + "password": admin_config["password"], + } + + log.step("Configuring Dify with admin account...") + + try: + # Make the setup request + with httpx.Client() as client: + response = client.post( + setup_endpoint, + json=setup_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 201: + log.success("Admin account created successfully!") + log.key_value("Email", admin_config["email"]) + log.key_value("Username", admin_config["username"]) + + elif response.status_code == 400: + log.warning("Setup may have already been completed or invalid data provided") + log.debug(f"Response: {response.text}") + else: + log.error(f"Setup failed with status code: {response.status_code}") + log.debug(f"Response: {response.text}") + + except httpx.ConnectError: + log.error("Could not connect to Dify API at http://localhost:5001") + log.info("Make sure the API server is running with: ./dev/start-api") + except Exception as e: + log.error(f"An error occurred: {e}") + + +if __name__ == "__main__": + setup_admin_account() diff --git a/scripts/stress-test/setup_all.py b/scripts/stress-test/setup_all.py new file mode 100755 index 0000000000..ece420f925 --- /dev/null +++ b/scripts/stress-test/setup_all.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 + +import socket +import subprocess +import sys +import time +from pathlib import Path + +from common import Logger, ProgressLogger + + +def run_script(script_name: str, description: str) -> bool: + """Run a Python script and return success status.""" + script_path = Path(__file__).parent / "setup" / script_name + + if not script_path.exists(): + print(f"❌ Script not found: {script_path}") + return False + + print(f"\n{'=' * 60}") + print(f"🚀 {description}") + print(f" Running: {script_name}") + print(f"{'=' * 60}") + + try: + result = subprocess.run( + [sys.executable, str(script_path)], + capture_output=True, + text=True, + check=False, + ) + + # Print output + if result.stdout: + print(result.stdout) + if result.stderr: + print(result.stderr, file=sys.stderr) + + if result.returncode != 0: + print(f"❌ Script failed with exit code: {result.returncode}") + return False + + print(f"✅ {script_name} completed successfully") + return True + + except Exception as e: + print(f"❌ Error running {script_name}: {e}") + return False + + +def check_port(host: str, port: int, service_name: str) -> bool: + """Check if a service is running on the specified port.""" + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + result = sock.connect_ex((host, port)) + sock.close() + + if result == 0: + Logger().success(f"{service_name} is running on port {port}") + return True + else: + Logger().error(f"{service_name} is not accessible on port {port}") + return False + except Exception as e: + Logger().error(f"Error checking {service_name}: {e}") + return False + + +def main() -> None: + """Run all setup scripts in order.""" + + log = Logger("Setup") + log.box("Dify Stress Test Setup - Full Installation") + + # Check if required services are running + log.step("Checking required services...") + log.separator() + + dify_running = check_port("localhost", 5001, "Dify API server") + if not dify_running: + log.info("To start Dify API server:") + log.list_item("Run: ./dev/start-api") + + mock_running = check_port("localhost", 5004, "Mock OpenAI server") + if not mock_running: + log.info("To start Mock OpenAI server:") + log.list_item("Run: python scripts/stress-test/setup/mock_openai_server.py") + + if not dify_running or not mock_running: + print("\n⚠️ Both services must be running before proceeding.") + retry = input("\nWould you like to check again? (yes/no): ") + if retry.lower() in ["yes", "y"]: + return main() # Recursively call main to check again + else: + print("❌ Setup cancelled. Please start the required services and try again.") + sys.exit(1) + + log.success("All required services are running!") + input("\nPress Enter to continue with setup...") + + # Define setup steps + setup_steps = [ + ("setup_admin.py", "Creating admin account"), + ("login_admin.py", "Logging in and getting access token"), + ("install_openai_plugin.py", "Installing OpenAI plugin"), + ("configure_openai_plugin.py", "Configuring OpenAI plugin with mock server"), + ("import_workflow_app.py", "Importing workflow application"), + ("create_api_key.py", "Creating API key for the app"), + ("publish_workflow.py", "Publishing the workflow"), + ] + + # Create progress logger + progress = ProgressLogger(len(setup_steps), log) + failed_step = None + + for script, description in setup_steps: + progress.next_step(description) + success = run_script(script, description) + + if not success: + failed_step = script + break + + # Small delay between steps + time.sleep(1) + + log.separator() + + if failed_step: + log.error(f"Setup failed at: {failed_step}") + log.separator() + log.info("Troubleshooting:") + log.list_item("Check if the Dify API server is running (./dev/start-api)") + log.list_item("Check if the mock OpenAI server is running (port 5004)") + log.list_item("Review the error messages above") + log.list_item("Run cleanup.py and try again") + sys.exit(1) + else: + progress.complete() + log.separator() + log.success("Setup completed successfully!") + log.info("Next steps:") + log.list_item("Test the workflow:") + log.info( + ' python scripts/stress-test/setup/run_workflow.py "Your question here"', + indent=4, + ) + log.list_item("To clean up and start over:") + log.info(" python scripts/stress-test/cleanup.py", indent=4) + + # Optionally run a test + log.separator() + test_input = input("Would you like to run a test workflow now? (yes/no): ") + + if test_input.lower() in ["yes", "y"]: + log.step("Running test workflow...") + run_script("run_workflow.py", "Testing workflow with default question") + + +if __name__ == "__main__": + main() diff --git a/scripts/stress-test/sse_benchmark.py b/scripts/stress-test/sse_benchmark.py new file mode 100644 index 0000000000..99fe2b20f4 --- /dev/null +++ b/scripts/stress-test/sse_benchmark.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +""" +SSE (Server-Sent Events) Stress Test for Dify Workflow API + +This script stress tests the streaming performance of Dify's workflow execution API, +measuring key metrics like connection rate, event throughput, and time to first event (TTFE). +""" + +import json +import logging +import os +import random +import statistics +import sys +import threading +import time +from collections import deque +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Literal, TypeAlias, TypedDict + +import requests.exceptions +from locust import HttpUser, between, constant, events, task + +# Add the stress-test directory to path to import common modules +sys.path.insert(0, str(Path(__file__).parent)) +from common.config_helper import ConfigHelper # type: ignore[import-not-found] + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# Configuration from environment +WORKFLOW_PATH = os.getenv("WORKFLOW_PATH", "/v1/workflows/run") +CONNECT_TIMEOUT = float(os.getenv("CONNECT_TIMEOUT", "10")) +READ_TIMEOUT = float(os.getenv("READ_TIMEOUT", "60")) +TERMINAL_EVENTS = [e.strip() for e in os.getenv("TERMINAL_EVENTS", "workflow_finished,error").split(",") if e.strip()] +QUESTIONS_FILE = os.getenv("QUESTIONS_FILE", "") + + +# Type definitions +ErrorType: TypeAlias = Literal[ + "connection_error", + "timeout", + "invalid_json", + "http_4xx", + "http_5xx", + "early_termination", + "invalid_response", +] + + +class ErrorCounts(TypedDict): + """Error count tracking""" + + connection_error: int + timeout: int + invalid_json: int + http_4xx: int + http_5xx: int + early_termination: int + invalid_response: int + + +class SSEEvent(TypedDict): + """Server-Sent Event structure""" + + data: str + event: str + id: str | None + + +class WorkflowInputs(TypedDict): + """Workflow input structure""" + + question: str + + +class WorkflowRequestData(TypedDict): + """Workflow request payload""" + + inputs: WorkflowInputs + response_mode: Literal["streaming"] + user: str + + +class ParsedEventData(TypedDict, total=False): + """Parsed event data from SSE stream""" + + event: str + task_id: str + workflow_run_id: str + data: object # For dynamic content + created_at: int + + +class LocustStats(TypedDict): + """Locust statistics structure""" + + total_requests: int + total_failures: int + avg_response_time: float + min_response_time: float + max_response_time: float + + +class ReportData(TypedDict): + """JSON report structure""" + + timestamp: str + duration_seconds: float + metrics: dict[str, object] # Metrics as dict for JSON serialization + locust_stats: LocustStats | None + + +@dataclass +class StreamMetrics: + """Metrics for a single stream""" + + stream_duration: float + events_count: int + bytes_received: int + ttfe: float + inter_event_times: list[float] + + +@dataclass +class MetricsSnapshot: + """Snapshot of current metrics state""" + + active_connections: int + total_connections: int + total_events: int + connection_rate: float + event_rate: float + overall_conn_rate: float + overall_event_rate: float + ttfe_avg: float + ttfe_min: float + ttfe_max: float + ttfe_p50: float + ttfe_p95: float + ttfe_samples: int + ttfe_total_samples: int # Total TTFE samples collected (not limited by window) + error_counts: ErrorCounts + stream_duration_avg: float + stream_duration_p50: float + stream_duration_p95: float + events_per_stream_avg: float + inter_event_latency_avg: float + inter_event_latency_p50: float + inter_event_latency_p95: float + + +class MetricsTracker: + def __init__(self) -> None: + self.lock = threading.Lock() + self.active_connections = 0 + self.total_connections = 0 + self.total_events = 0 + self.start_time = time.time() + + # Enhanced metrics with memory limits + self.max_samples = 10000 # Prevent unbounded growth + self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples) + self.ttfe_total_count = 0 # Track total TTFE samples collected + + # For rate calculations - no maxlen to avoid artificial limits + self.connection_times: deque[float] = deque() + self.event_times: deque[float] = deque() + self.last_stats_time = time.time() + self.last_total_connections = 0 + self.last_total_events = 0 + self.stream_metrics: deque[StreamMetrics] = deque(maxlen=self.max_samples) + self.error_counts: ErrorCounts = ErrorCounts( + connection_error=0, + timeout=0, + invalid_json=0, + http_4xx=0, + http_5xx=0, + early_termination=0, + invalid_response=0, + ) + + def connection_started(self) -> None: + with self.lock: + self.active_connections += 1 + self.total_connections += 1 + self.connection_times.append(time.time()) + + def connection_ended(self) -> None: + with self.lock: + self.active_connections -= 1 + + def event_received(self) -> None: + with self.lock: + self.total_events += 1 + self.event_times.append(time.time()) + + def record_ttfe(self, ttfe_ms: float) -> None: + with self.lock: + self.ttfe_samples.append(ttfe_ms) # deque handles maxlen + self.ttfe_total_count += 1 # Increment total counter + + def record_stream_metrics(self, metrics: StreamMetrics) -> None: + with self.lock: + self.stream_metrics.append(metrics) # deque handles maxlen + + def record_error(self, error_type: ErrorType) -> None: + with self.lock: + self.error_counts[error_type] += 1 + + def get_stats(self) -> MetricsSnapshot: + with self.lock: + current_time = time.time() + time_window = 10.0 # 10 second window for rate calculation + + # Clean up old timestamps outside the window + cutoff_time = current_time - time_window + while self.connection_times and self.connection_times[0] < cutoff_time: + self.connection_times.popleft() + while self.event_times and self.event_times[0] < cutoff_time: + self.event_times.popleft() + + # Calculate rates based on actual window or elapsed time + window_duration = min(time_window, current_time - self.start_time) + if window_duration > 0: + conn_rate = len(self.connection_times) / window_duration + event_rate = len(self.event_times) / window_duration + else: + conn_rate = 0 + event_rate = 0 + + # Calculate TTFE statistics + if self.ttfe_samples: + avg_ttfe = statistics.mean(self.ttfe_samples) + min_ttfe = min(self.ttfe_samples) + max_ttfe = max(self.ttfe_samples) + p50_ttfe = statistics.median(self.ttfe_samples) + if len(self.ttfe_samples) >= 2: + quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive") + p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile + else: + p95_ttfe = max_ttfe + else: + avg_ttfe = min_ttfe = max_ttfe = p50_ttfe = p95_ttfe = 0 + + # Calculate stream metrics + if self.stream_metrics: + durations = [m.stream_duration for m in self.stream_metrics] + events_per_stream = [m.events_count for m in self.stream_metrics] + stream_duration_avg = statistics.mean(durations) + stream_duration_p50 = statistics.median(durations) + stream_duration_p95 = ( + statistics.quantiles(durations, n=20, method="inclusive")[18] + if len(durations) >= 2 + else max(durations) + if durations + else 0 + ) + events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0 + + # Calculate inter-event latency statistics + all_inter_event_times = [] + for m in self.stream_metrics: + all_inter_event_times.extend(m.inter_event_times) + + if all_inter_event_times: + inter_event_latency_avg = statistics.mean(all_inter_event_times) + inter_event_latency_p50 = statistics.median(all_inter_event_times) + inter_event_latency_p95 = ( + statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18] + if len(all_inter_event_times) >= 2 + else max(all_inter_event_times) + ) + else: + inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0 + else: + stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0 + inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0 + + # Also calculate overall average rates + total_elapsed = current_time - self.start_time + overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0 + overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0 + + return MetricsSnapshot( + active_connections=self.active_connections, + total_connections=self.total_connections, + total_events=self.total_events, + connection_rate=conn_rate, + event_rate=event_rate, + overall_conn_rate=overall_conn_rate, + overall_event_rate=overall_event_rate, + ttfe_avg=avg_ttfe, + ttfe_min=min_ttfe, + ttfe_max=max_ttfe, + ttfe_p50=p50_ttfe, + ttfe_p95=p95_ttfe, + ttfe_samples=len(self.ttfe_samples), + ttfe_total_samples=self.ttfe_total_count, # Return total count + error_counts=ErrorCounts(**self.error_counts), + stream_duration_avg=stream_duration_avg, + stream_duration_p50=stream_duration_p50, + stream_duration_p95=stream_duration_p95, + events_per_stream_avg=events_per_stream_avg, + inter_event_latency_avg=inter_event_latency_avg, + inter_event_latency_p50=inter_event_latency_p50, + inter_event_latency_p95=inter_event_latency_p95, + ) + + +# Global metrics instance +metrics = MetricsTracker() + + +class SSEParser: + """Parser for Server-Sent Events according to W3C spec""" + + def __init__(self) -> None: + self.data_buffer: list[str] = [] + self.event_type: str | None = None + self.event_id: str | None = None + + def parse_line(self, line: str) -> SSEEvent | None: + """Parse a single SSE line and return event if complete""" + # Empty line signals end of event + if not line: + if self.data_buffer: + event = SSEEvent( + data="\n".join(self.data_buffer), + event=self.event_type or "message", + id=self.event_id, + ) + self.data_buffer = [] + self.event_type = None + self.event_id = None + return event + return None + + # Comment line + if line.startswith(":"): + return None + + # Parse field + if ":" in line: + field, value = line.split(":", 1) + value = value.lstrip() + + if field == "data": + self.data_buffer.append(value) + elif field == "event": + self.event_type = value + elif field == "id": + self.event_id = value + + return None + + +# Note: SSEClient removed - we'll handle SSE parsing directly in the task for better Locust integration + + +class DifyWorkflowUser(HttpUser): + """Locust user for testing Dify workflow SSE endpoints""" + + # Use constant wait for streaming workloads + wait_time = constant(0) if os.getenv("WAIT_TIME", "0") == "0" else between(1, 3) + + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + + # Load API configuration + config_helper = ConfigHelper() + self.api_token = config_helper.get_api_key() + + if not self.api_token: + raise ValueError("API key not found. Please run setup_all.py first.") + + # Load questions from file or use defaults + if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE): + with open(QUESTIONS_FILE) as f: + self.questions = [line.strip() for line in f if line.strip()] + else: + self.questions = [ + "What is artificial intelligence?", + "Explain quantum computing", + "What is machine learning?", + "How do neural networks work?", + "What is renewable energy?", + ] + + self.user_counter = 0 + + def on_start(self) -> None: + """Called when a user starts""" + self.user_counter = 0 + + @task + def test_workflow_stream(self) -> None: + """Test workflow SSE streaming endpoint""" + + question = random.choice(self.questions) + self.user_counter += 1 + + headers = { + "Authorization": f"Bearer {self.api_token}", + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + data = WorkflowRequestData( + inputs=WorkflowInputs(question=question), + response_mode="streaming", + user=f"user_{self.user_counter}", + ) + + start_time = time.time() + first_event_time = None + event_count = 0 + inter_event_times: list[float] = [] + last_event_time = None + ttfe = 0 + request_success = False + bytes_received = 0 + + metrics.connection_started() + + # Use catch_response context manager directly + with self.client.request( + method="POST", + url=WORKFLOW_PATH, + headers=headers, + json=data, + stream=True, + catch_response=True, + timeout=(CONNECT_TIMEOUT, READ_TIMEOUT), + name="/v1/workflows/run", # Name for Locust stats + ) as response: + try: + # Validate response + if response.status_code >= 400: + error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx" + metrics.record_error(error_type) + response.failure(f"HTTP {response.status_code}") + return + + content_type = response.headers.get("Content-Type", "") + if "text/event-stream" not in content_type and "application/json" not in content_type: + logger.error(f"Expected text/event-stream, got: {content_type}") + metrics.record_error("invalid_response") + response.failure(f"Invalid content type: {content_type}") + return + + # Parse SSE events + parser = SSEParser() + + for line in response.iter_lines(decode_unicode=True): + # Check if runner is stopping + if getattr(self.environment.runner, "state", "") in ( + "stopping", + "stopped", + ): + logger.debug("Runner stopping, breaking streaming loop") + break + + if line is not None: + bytes_received += len(line.encode("utf-8")) + + # Parse SSE line + event = parser.parse_line(line if line is not None else "") + if event: + event_count += 1 + current_time = time.time() + metrics.event_received() + + # Track inter-event timing + if last_event_time: + inter_event_times.append((current_time - last_event_time) * 1000) + last_event_time = current_time + + if first_event_time is None: + first_event_time = current_time + ttfe = (first_event_time - start_time) * 1000 + metrics.record_ttfe(ttfe) + + try: + # Parse event data + event_data = event.get("data", "") + if event_data: + if event_data == "[DONE]": + logger.debug("Received [DONE] sentinel") + request_success = True + break + + try: + parsed_event: ParsedEventData = json.loads(event_data) + # Check for terminal events + if parsed_event.get("event") in TERMINAL_EVENTS: + logger.debug(f"Received terminal event: {parsed_event.get('event')}") + request_success = True + break + except json.JSONDecodeError as e: + logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}") + metrics.record_error("invalid_json") + + except Exception as e: + logger.error(f"Error processing event: {e}") + + # Mark success only if terminal condition was met or events were received + if request_success: + response.success() + elif event_count > 0: + # Got events but no proper terminal condition + metrics.record_error("early_termination") + response.failure("Stream ended without terminal event") + else: + response.failure("No events received") + + except ( + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ) as e: + metrics.record_error("timeout") + response.failure(f"Timeout: {e}") + except ( + requests.exceptions.ConnectionError, + requests.exceptions.RequestException, + ) as e: + metrics.record_error("connection_error") + response.failure(f"Connection error: {e}") + except Exception as e: + response.failure(str(e)) + raise + finally: + metrics.connection_ended() + + # Record stream metrics + if event_count > 0: + stream_duration = (time.time() - start_time) * 1000 + stream_metrics = StreamMetrics( + stream_duration=stream_duration, + events_count=event_count, + bytes_received=bytes_received, + ttfe=ttfe, + inter_event_times=inter_event_times, + ) + metrics.record_stream_metrics(stream_metrics) + logger.debug( + f"Stream completed: {event_count} events, {stream_duration:.1f}ms, success={request_success}" + ) + else: + logger.warning("No events received in stream") + + +# Event handlers +@events.test_start.add_listener # type: ignore[misc] +def on_test_start(environment: object, **kwargs: object) -> None: + logger.info("=" * 80) + logger.info(" " * 25 + "DIFY SSE BENCHMARK - REAL-TIME METRICS") + logger.info("=" * 80) + logger.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info("=" * 80) + + # Periodic stats reporting + def report_stats() -> None: + if not hasattr(environment, "runner"): + return + runner = environment.runner + while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]: + time.sleep(5) # Report every 5 seconds + if hasattr(runner, "state") and runner.state == "running": + stats = metrics.get_stats() + + # Only log on master node in distributed mode + is_master = ( + not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True + ) + if is_master: + # Clear previous lines and show updated stats + logger.info("\n" + "=" * 80) + logger.info( + f"{'METRIC':<25} {'CURRENT':>15} {'RATE (10s)':>15} {'AVG (overall)':>15} {'TOTAL':>12}" + ) + logger.info("-" * 80) + + # Active SSE Connections + logger.info( + f"{'Active SSE Connections':<25} {stats.active_connections:>15,d} {'-':>15} {'-':>12} {'-':>12}" + ) + + # New Connection Rate + logger.info( + f"{'New Connections':<25} {'-':>15} {stats.connection_rate:>13.2f}/s {stats.overall_conn_rate:>13.2f}/s {stats.total_connections:>12,d}" + ) + + # Event Throughput + logger.info( + f"{'Event Throughput':<25} {'-':>15} {stats.event_rate:>13.2f}/s {stats.overall_event_rate:>13.2f}/s {stats.total_events:>12,d}" + ) + + logger.info("-" * 80) + logger.info( + f"{'TIME TO FIRST EVENT':<25} {'AVG':>15} {'P50':>10} {'P95':>10} {'MIN':>10} {'MAX':>10}" + ) + logger.info( + f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}" + ) + logger.info( + f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)" + ) + logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}") + + # Inter-event latency + if stats.inter_event_latency_avg > 0: + logger.info("-" * 80) + logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}") + logger.info( + f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}" + ) + + # Error stats + if any(stats.error_counts.values()): + logger.info("-" * 80) + logger.info(f"{'ERROR TYPE':<25} {'COUNT':>15}") + for error_type, count in stats.error_counts.items(): + if isinstance(count, int) and count > 0: + logger.info(f"{error_type:<25} {count:>15,d}") + + logger.info("=" * 80) + + # Show Locust stats summary + if hasattr(environment, "stats") and hasattr(environment.stats, "total"): + total = environment.stats.total + if hasattr(total, "num_requests") and total.num_requests > 0: + logger.info( + f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}" + ) + logger.info("-" * 80) + logger.info( + f"{'Aggregated':<25} {total.num_requests:>12,d} " + f"{total.num_failures:>8,d} " + f"{total.avg_response_time:>12.1f} " + f"{total.min_response_time:>8.0f} " + f"{total.max_response_time:>8.0f}" + ) + logger.info("=" * 80) + + threading.Thread(target=report_stats, daemon=True).start() + + +@events.test_stop.add_listener # type: ignore[misc] +def on_test_stop(environment: object, **kwargs: object) -> None: + stats = metrics.get_stats() + test_duration = time.time() - metrics.start_time + + # Log final results + logger.info("\n" + "=" * 80) + logger.info(" " * 30 + "FINAL BENCHMARK RESULTS") + logger.info("=" * 80) + logger.info(f"Test Duration: {test_duration:.1f} seconds") + logger.info("-" * 80) + + logger.info("") + logger.info("CONNECTIONS") + logger.info(f" {'Total Connections:':<30} {stats.total_connections:>10,d}") + logger.info(f" {'Final Active:':<30} {stats.active_connections:>10,d}") + logger.info(f" {'Average Rate:':<30} {stats.overall_conn_rate:>10.2f} conn/s") + + logger.info("") + logger.info("EVENTS") + logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}") + logger.info(f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s") + logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s") + + logger.info("") + logger.info("STREAM METRICS") + logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms") + logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms") + logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms") + logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}") + + logger.info("") + logger.info("INTER-EVENT LATENCY") + logger.info(f" {'Average:':<30} {stats.inter_event_latency_avg:>10.1f} ms") + logger.info(f" {'Median (P50):':<30} {stats.inter_event_latency_p50:>10.1f} ms") + logger.info(f" {'95th Percentile:':<30} {stats.inter_event_latency_p95:>10.1f} ms") + + logger.info("") + logger.info("TIME TO FIRST EVENT (ms)") + logger.info(f" {'Average:':<30} {stats.ttfe_avg:>10.1f} ms") + logger.info(f" {'Median (P50):':<30} {stats.ttfe_p50:>10.1f} ms") + logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms") + logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms") + logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms") + logger.info( + f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})" + ) + logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}") + + # Error summary + if any(stats.error_counts.values()): + logger.info("") + logger.info("ERRORS") + for error_type, count in stats.error_counts.items(): + if isinstance(count, int) and count > 0: + logger.info(f" {error_type:<30} {count:>10,d}") + + logger.info("=" * 80 + "\n") + + # Export machine-readable report (only on master node) + is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True + if is_master: + export_json_report(stats, test_duration, environment) + + +def export_json_report(stats: MetricsSnapshot, duration: float, environment: object) -> None: + """Export metrics to JSON file for CI/CD analysis""" + + reports_dir = Path(__file__).parent / "reports" + reports_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + report_file = reports_dir / f"sse_metrics_{timestamp}.json" + + # Access environment.stats.total attributes safely + locust_stats: LocustStats | None = None + if hasattr(environment, "stats") and hasattr(environment.stats, "total"): + total = environment.stats.total + if hasattr(total, "num_requests") and total.num_requests > 0: + locust_stats = LocustStats( + total_requests=total.num_requests, + total_failures=total.num_failures, + avg_response_time=total.avg_response_time, + min_response_time=total.min_response_time, + max_response_time=total.max_response_time, + ) + + report_data = ReportData( + timestamp=datetime.now().isoformat(), + duration_seconds=duration, + metrics=asdict(stats), # type: ignore[arg-type] + locust_stats=locust_stats, + ) + + with open(report_file, "w") as f: + json.dump(report_data, f, indent=2) + + logger.info(f"Exported metrics to {report_file}") diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index 0ba7bba8bb..3025cc2ab6 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -95,10 +95,9 @@ export class DifyClient { headerParams = {} ) { const headers = { - ...{ + Authorization: `Bearer ${this.apiKey}`, "Content-Type": "application/json", - }, ...headerParams }; diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index d00c207afa..e866472f45 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -1,7 +1,15 @@ from dify_client.client import ( ChatClient, CompletionClient, - WorkflowClient, - KnowledgeBaseClient, DifyClient, + KnowledgeBaseClient, + WorkflowClient, ) + +__all__ = [ + "ChatClient", + "CompletionClient", + "DifyClient", + "KnowledgeBaseClient", + "WorkflowClient", +] diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index abd0e7ae29..791cb98a1b 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,5 +1,5 @@ import json - +from typing import Literal import requests @@ -8,16 +8,16 @@ class DifyClient: self.api_key = api_key self.base_url = base_url - def _send_request(self, method, endpoint, json=None, params=None, stream=False): + def _send_request( + self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False + ): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } url = f"{self.base_url}{endpoint}" - response = requests.request( - method, url, json=json, params=params, headers=headers, stream=stream - ) + response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream) return response @@ -25,37 +25,35 @@ class DifyClient: headers = {"Authorization": f"Bearer {self.api_key}"} url = f"{self.base_url}{endpoint}" - response = requests.request( - method, url, data=data, headers=headers, files=files - ) + response = requests.request(method, url, data=data, headers=headers, files=files) return response - def message_feedback(self, message_id, rating, user): + def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): data = {"rating": rating, "user": user} return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - def get_application_parameters(self, user): + def get_application_parameters(self, user: str): params = {"user": user} return self._send_request("GET", "/parameters", params=params) - def file_upload(self, user, files): + def file_upload(self, user: str, files: dict): data = {"user": user} - return self._send_request_with_files( - "POST", "/files/upload", data=data, files=files - ) + return self._send_request_with_files("POST", "/files/upload", data=data, files=files) def text_to_audio(self, text: str, user: str, streaming: bool = False): data = {"text": text, "user": user, "streaming": streaming} return self._send_request("POST", "/text-to-audio", json=data) - def get_meta(self, user): + def get_meta(self, user: str): params = {"user": user} return self._send_request("GET", "/meta", params=params) class CompletionClient(DifyClient): - def create_completion_message(self, inputs, response_mode, user, files=None): + def create_completion_message( + self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None + ): data = { "inputs": inputs, "response_mode": response_mode, @@ -76,7 +74,7 @@ class ChatClient(DifyClient): inputs: dict, query: str, user: str, - response_mode: str = "blocking", + response_mode: Literal["blocking", "streaming"] = "blocking", conversation_id: str | None = None, files: dict | None = None, ): @@ -99,9 +97,7 @@ class ChatClient(DifyClient): def get_suggested(self, message_id: str, user: str): params = {"user": user} - return self._send_request( - "GET", f"/messages/{message_id}/suggested", params=params - ) + return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) def stop_message(self, task_id: str, user: str): data = {"user": user} @@ -112,10 +108,9 @@ class ChatClient(DifyClient): user: str, last_id: str | None = None, limit: int | None = None, - pinned: bool | None = None + pinned: bool | None = None, ): - params = {"user": user, "last_id": last_id, - "limit": limit, "pinned": pinned} + params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} return self._send_request("GET", "/conversations", params=params) def get_conversation_messages( @@ -123,7 +118,7 @@ class ChatClient(DifyClient): user: str, conversation_id: str | None = None, first_id: str | None = None, - limit: int | None = None + limit: int | None = None, ): params = {"user": user} @@ -136,13 +131,9 @@ class ChatClient(DifyClient): return self._send_request("GET", "/messages", params=params) - def rename_conversation( - self, conversation_id: str, name: str, auto_generate: bool, user: str - ): + def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request( - "POST", f"/conversations/{conversation_id}/name", data - ) + return self._send_request("POST", f"/conversations/{conversation_id}/name", data) def delete_conversation(self, conversation_id: str, user: str): data = {"user": user} @@ -155,9 +146,7 @@ class ChatClient(DifyClient): class WorkflowClient(DifyClient): - def run( - self, inputs: dict, response_mode: str = "streaming", user: str = "abc-123" - ): + def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"): data = {"inputs": inputs, "response_mode": response_mode, "user": user} return self._send_request("POST", "/workflows/run", data) @@ -172,7 +161,7 @@ class WorkflowClient(DifyClient): class KnowledgeBaseClient(DifyClient): def __init__( self, - api_key, + api_key: str, base_url: str = "https://api.dify.ai/v1", dataset_id: str | None = None, ): @@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient): return self._send_request("POST", "/datasets", {"name": name}, **kwargs) def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request( - "GET", f"/datasets?page={page}&limit={page_size}", **kwargs - ) + return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs) - def create_document_by_text( - self, name, text, extra_params: dict | None = None, **kwargs - ): + def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs): """ Create a document by text. @@ -241,7 +226,7 @@ class KnowledgeBaseClient(DifyClient): return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict | None = None, **kwargs + self, document_id: str, name: str, text: str, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -272,13 +257,11 @@ class KnowledgeBaseClient(DifyClient): data = {"name": name, "text": text} if extra_params is not None and isinstance(extra_params, dict): data.update(extra_params) - url = ( - f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - ) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict | None = None + self, file_path: str, original_document_id: str | None = None, extra_params: dict | None = None ): """ Create a document by file. @@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient): if original_document_id is not None: data["original_document_id"] = original_document_id url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files( - "POST", url, {"data": json.dumps(data)}, files - ) + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - def update_document_by_file( - self, document_id, file_path, extra_params: dict | None = None - ): + def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None): """ Update a document by file. @@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient): data = {} if extra_params is not None and isinstance(extra_params, dict): data.update(extra_params) - url = ( - f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - ) - return self._send_request_with_files( - "POST", url, {"data": json.dumps(data)}, files - ) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) def batch_indexing_status(self, batch_id: str, **kwargs): """ @@ -377,7 +352,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}" return self._send_request("DELETE", url) - def delete_document(self, document_id): + def delete_document(self, document_id: str): """ Delete a document. @@ -409,7 +384,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}/documents" return self._send_request("GET", url, params=params, **kwargs) - def add_segments(self, document_id, segments, **kwargs): + def add_segments(self, document_id: str, segments: list[dict], **kwargs): """ Add segments to a document. @@ -423,7 +398,7 @@ class KnowledgeBaseClient(DifyClient): def query_segments( self, - document_id, + document_id: str, keyword: str | None = None, status: str | None = None, **kwargs, @@ -445,7 +420,7 @@ class KnowledgeBaseClient(DifyClient): params.update(kwargs["params"]) return self._send_request("GET", url, params=params, **kwargs) - def delete_document_segment(self, document_id, segment_id): + def delete_document_segment(self, document_id: str, segment_id: str): """ Delete a segment from a document. @@ -456,7 +431,7 @@ class KnowledgeBaseClient(DifyClient): url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" return self._send_request("DELETE", url) - def update_document_segment(self, document_id, segment_id, segment_data, **kwargs): + def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): """ Update a segment in a document. diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index 7340fffb4c..a05f6410fb 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -with open("README.md", "r", encoding="utf-8") as fh: +with open("README.md", encoding="utf-8") as fh: long_description = fh.read() setup( diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index 52032417c0..fce1b11eba 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__) class TestKnowledgeBaseClient(unittest.TestCase): def setUp(self): self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) - self.README_FILE_PATH = os.path.abspath( - os.path.join(FILE_PATH_BASE, "../README.md") - ) + self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) self.dataset_id = None self.document_id = None self.segment_id = None @@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _get_dataset_kb_client(self): self.assertIsNotNone(self.dataset_id) - return KnowledgeBaseClient( - API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id - ) + return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id) def test_001_create_dataset(self): response = self.knowledge_base_client.create_dataset(name="test_dataset") @@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_004_update_document_by_text(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) - response = client.update_document_by_text( - self.document_id, "test_document_updated", "test_text_updated" - ) + response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") data = response.json() self.assertIn("document", data) self.assertIn("batch", data) @@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_006_update_document_by_file(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) - response = client.update_document_by_file( - self.document_id, self.README_FILE_PATH - ) + response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) data = response.json() self.assertIn("document", data) self.assertIn("batch", data) @@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_010_add_segments(self): client = self._get_dataset_kb_client() - response = client.add_segments( - self.document_id, [{"content": "test text segment 1"}] - ) + response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) data = response.json() self.assertIn("data", data) self.assertGreater(len(data["data"]), 0) @@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase): self.chat_client = ChatClient(API_KEY) def test_create_chat_message(self): - response = self.chat_client.create_chat_message( - {}, "Hello, World!", "test_user" - ) + response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user") self.assertIn("answer", response.text) def test_create_chat_message_with_vision_model_by_remote_url(self): - files = [ - {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} - ] - response = self.chat_client.create_chat_message( - {}, "Describe the picture.", "test_user", files=files - ) + files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] + response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) self.assertIn("answer", response.text) def test_create_chat_message_with_vision_model_by_local_file(self): @@ -196,15 +180,11 @@ class TestChatClient(unittest.TestCase): "upload_file_id": "your_file_id", } ] - response = self.chat_client.create_chat_message( - {}, "Describe the picture.", "test_user", files=files - ) + response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) self.assertIn("answer", response.text) def test_get_conversation_messages(self): - response = self.chat_client.get_conversation_messages( - "test_user", "your_conversation_id" - ) + response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id") self.assertIn("answer", response.text) def test_get_conversations(self): @@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase): self.assertIn("answer", response.text) def test_create_completion_message_with_vision_model_by_remote_url(self): - files = [ - {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} - ] + files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] response = self.completion_client.create_completion_message( {"query": "Describe the picture."}, "blocking", "test_user", files ) @@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase): self.dify_client = DifyClient(API_KEY) def test_message_feedback(self): - response = self.dify_client.message_feedback( - "your_message_id", "like", "test_user" - ) + response = self.dify_client.message_feedback("your_message_id", "like", "test_user") self.assertIn("success", response.text) def test_get_application_parameters(self): diff --git a/web/.oxlintrc.json b/web/.oxlintrc.json index 1bfcca58f5..57eddd34fb 100644 --- a/web/.oxlintrc.json +++ b/web/.oxlintrc.json @@ -45,7 +45,7 @@ "no-unassigned-vars": "warn", "no-unsafe-finally": "warn", "no-unsafe-negation": "warn", - "no-unsafe-optional-chaining": "warn", + "no-unsafe-optional-chaining": "error", "no-unused-labels": "warn", "no-unused-private-class-members": "warn", "no-unused-vars": "warn", diff --git a/web/__tests__/goto-anything/match-action.test.ts b/web/__tests__/goto-anything/match-action.test.ts new file mode 100644 index 0000000000..3df9c0d533 --- /dev/null +++ b/web/__tests__/goto-anything/match-action.test.ts @@ -0,0 +1,235 @@ +import type { ActionItem } from '../../app/components/goto-anything/actions/types' + +// Mock the entire actions module to avoid import issues +jest.mock('../../app/components/goto-anything/actions', () => ({ + matchAction: jest.fn(), +})) + +jest.mock('../../app/components/goto-anything/actions/commands/registry') + +// Import after mocking to get mocked version +import { matchAction } from '../../app/components/goto-anything/actions' +import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' + +// Implement the actual matchAction logic for testing +const actualMatchAction = (query: string, actions: Record) => { + const result = Object.values(actions).find((action) => { + // Special handling for slash commands + if (action.key === '/') { + // Get all registered commands from the registry + const allCommands = slashCommandRegistry.getAllCommands() + + // Check if query matches any registered command + return allCommands.some((cmd) => { + const cmdPattern = `/${cmd.name}` + + // For direct mode commands, don't match (keep in command selector) + if (cmd.mode === 'direct') + return false + + // For submenu mode commands, match when complete command is entered + return query === cmdPattern || query.startsWith(`${cmdPattern} `) + }) + } + + const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`) + return reg.test(query) + }) + return result +} + +// Replace mock with actual implementation +;(matchAction as jest.Mock).mockImplementation(actualMatchAction) + +describe('matchAction Logic', () => { + const mockActions: Record = { + app: { + key: '@app', + shortcut: '@a', + title: 'Search Applications', + description: 'Search apps', + search: jest.fn(), + }, + knowledge: { + key: '@knowledge', + shortcut: '@kb', + title: 'Search Knowledge', + description: 'Search knowledge bases', + search: jest.fn(), + }, + slash: { + key: '/', + shortcut: '/', + title: 'Commands', + description: 'Execute commands', + search: jest.fn(), + }, + } + + beforeEach(() => { + jest.clearAllMocks() + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'docs', mode: 'direct' }, + { name: 'community', mode: 'direct' }, + { name: 'feedback', mode: 'direct' }, + { name: 'account', mode: 'direct' }, + { name: 'theme', mode: 'submenu' }, + { name: 'language', mode: 'submenu' }, + ]) + }) + + describe('@ Actions Matching', () => { + it('should match @app with key', () => { + const result = matchAction('@app', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should match @app with shortcut', () => { + const result = matchAction('@a', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should match @knowledge with key', () => { + const result = matchAction('@knowledge', mockActions) + expect(result).toBe(mockActions.knowledge) + }) + + it('should match @knowledge with shortcut @kb', () => { + const result = matchAction('@kb', mockActions) + expect(result).toBe(mockActions.knowledge) + }) + + it('should match with text after action', () => { + const result = matchAction('@app search term', mockActions) + expect(result).toBe(mockActions.app) + }) + + it('should not match partial @ actions', () => { + const result = matchAction('@ap', mockActions) + expect(result).toBeUndefined() + }) + }) + + describe('Slash Commands Matching', () => { + describe('Direct Mode Commands', () => { + it('should not match direct mode commands', () => { + const result = matchAction('/docs', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match direct mode with arguments', () => { + const result = matchAction('/docs something', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match any direct mode command', () => { + expect(matchAction('/community', mockActions)).toBeUndefined() + expect(matchAction('/feedback', mockActions)).toBeUndefined() + expect(matchAction('/account', mockActions)).toBeUndefined() + }) + }) + + describe('Submenu Mode Commands', () => { + it('should match submenu mode commands exactly', () => { + const result = matchAction('/theme', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should match submenu mode with arguments', () => { + const result = matchAction('/theme dark', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should match all submenu commands', () => { + expect(matchAction('/language', mockActions)).toBe(mockActions.slash) + expect(matchAction('/language en', mockActions)).toBe(mockActions.slash) + }) + }) + + describe('Slash Without Command', () => { + it('should not match single slash', () => { + const result = matchAction('/', mockActions) + expect(result).toBeUndefined() + }) + + it('should not match unregistered commands', () => { + const result = matchAction('/unknown', mockActions) + expect(result).toBeUndefined() + }) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty query', () => { + const result = matchAction('', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle whitespace only', () => { + const result = matchAction(' ', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle regular text without actions', () => { + const result = matchAction('search something', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle special characters', () => { + const result = matchAction('#tag', mockActions) + expect(result).toBeUndefined() + }) + + it('should handle multiple @ or /', () => { + expect(matchAction('@@app', mockActions)).toBeUndefined() + expect(matchAction('//theme', mockActions)).toBeUndefined() + }) + }) + + describe('Mode-based Filtering', () => { + it('should filter direct mode commands from matching', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test', mode: 'direct' }, + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBeUndefined() + }) + + it('should allow submenu mode commands to match', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test', mode: 'submenu' }, + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBe(mockActions.slash) + }) + + it('should treat undefined mode as submenu', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([ + { name: 'test' }, // No mode specified + ]) + + const result = matchAction('/test', mockActions) + expect(result).toBe(mockActions.slash) + }) + }) + + describe('Registry Integration', () => { + it('should call getAllCommands when matching slash', () => { + matchAction('/theme', mockActions) + expect(slashCommandRegistry.getAllCommands).toHaveBeenCalled() + }) + + it('should not call getAllCommands for @ actions', () => { + matchAction('@app', mockActions) + expect(slashCommandRegistry.getAllCommands).not.toHaveBeenCalled() + }) + + it('should handle empty command list', () => { + ;(slashCommandRegistry.getAllCommands as jest.Mock).mockReturnValue([]) + const result = matchAction('/anything', mockActions) + expect(result).toBeUndefined() + }) + }) +}) diff --git a/web/__tests__/goto-anything/scope-command-tags.test.tsx b/web/__tests__/goto-anything/scope-command-tags.test.tsx new file mode 100644 index 0000000000..339e259a06 --- /dev/null +++ b/web/__tests__/goto-anything/scope-command-tags.test.tsx @@ -0,0 +1,134 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import '@testing-library/jest-dom' + +// Type alias for search mode +type SearchMode = 'scopes' | 'commands' | null + +// Mock component to test tag display logic +const TagDisplay: React.FC<{ searchMode: SearchMode }> = ({ searchMode }) => { + if (!searchMode) return null + + return ( +
+ {searchMode === 'scopes' ? 'SCOPES' : 'COMMANDS'} +
+ ) +} + +describe('Scope and Command Tags', () => { + describe('Tag Display Logic', () => { + it('should display SCOPES for @ actions', () => { + render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + expect(screen.queryByText('COMMANDS')).not.toBeInTheDocument() + }) + + it('should display COMMANDS for / actions', () => { + render() + expect(screen.getByText('COMMANDS')).toBeInTheDocument() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + }) + + it('should not display any tag when searchMode is null', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + }) + + describe('Search Mode Detection', () => { + const getSearchMode = (query: string): SearchMode => { + if (query.startsWith('@')) return 'scopes' + if (query.startsWith('/')) return 'commands' + return null + } + + it('should detect scopes mode for @ queries', () => { + expect(getSearchMode('@app')).toBe('scopes') + expect(getSearchMode('@knowledge')).toBe('scopes') + expect(getSearchMode('@plugin')).toBe('scopes') + expect(getSearchMode('@node')).toBe('scopes') + }) + + it('should detect commands mode for / queries', () => { + expect(getSearchMode('/theme')).toBe('commands') + expect(getSearchMode('/language')).toBe('commands') + expect(getSearchMode('/docs')).toBe('commands') + }) + + it('should return null for regular queries', () => { + expect(getSearchMode('')).toBe(null) + expect(getSearchMode('search term')).toBe(null) + expect(getSearchMode('app')).toBe(null) + }) + + it('should handle queries with spaces', () => { + expect(getSearchMode('@app search')).toBe('scopes') + expect(getSearchMode('/theme dark')).toBe('commands') + }) + }) + + describe('Tag Styling', () => { + it('should apply correct styling classes', () => { + const { container } = render() + const tagContainer = container.querySelector('.flex.items-center.gap-1.text-xs.text-text-tertiary') + expect(tagContainer).toBeInTheDocument() + }) + + it('should use hardcoded English text', () => { + // Verify that tags are hardcoded and not using i18n + render() + const scopesText = screen.getByText('SCOPES') + expect(scopesText.textContent).toBe('SCOPES') + + render() + const commandsText = screen.getByText('COMMANDS') + expect(commandsText.textContent).toBe('COMMANDS') + }) + }) + + describe('Integration with Search States', () => { + const SearchComponent: React.FC<{ query: string }> = ({ query }) => { + let searchMode: SearchMode = null + + if (query.startsWith('@')) searchMode = 'scopes' + else if (query.startsWith('/')) searchMode = 'commands' + + return ( +
+ + +
+ ) + } + + it('should update tag when switching between @ and /', () => { + const { rerender } = render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + expect(screen.getByText('COMMANDS')).toBeInTheDocument() + }) + + it('should hide tag when clearing search', () => { + const { rerender } = render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.queryByText('SCOPES')).not.toBeInTheDocument() + expect(screen.queryByText('COMMANDS')).not.toBeInTheDocument() + }) + + it('should maintain correct tag during search refinement', () => { + const { rerender } = render() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + + rerender() + expect(screen.getByText('SCOPES')).toBeInTheDocument() + }) + }) +}) diff --git a/web/__tests__/goto-anything/slash-command-modes.test.tsx b/web/__tests__/goto-anything/slash-command-modes.test.tsx new file mode 100644 index 0000000000..f8126958fc --- /dev/null +++ b/web/__tests__/goto-anything/slash-command-modes.test.tsx @@ -0,0 +1,212 @@ +import '@testing-library/jest-dom' +import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry' +import type { SlashCommandHandler } from '../../app/components/goto-anything/actions/commands/types' + +// Mock the registry +jest.mock('../../app/components/goto-anything/actions/commands/registry') + +describe('Slash Command Dual-Mode System', () => { + const mockDirectCommand: SlashCommandHandler = { + name: 'docs', + description: 'Open documentation', + mode: 'direct', + execute: jest.fn(), + search: jest.fn().mockResolvedValue([ + { + id: 'docs', + title: 'Documentation', + description: 'Open documentation', + type: 'command' as const, + data: { command: 'navigation.docs', args: {} }, + }, + ]), + register: jest.fn(), + unregister: jest.fn(), + } + + const mockSubmenuCommand: SlashCommandHandler = { + name: 'theme', + description: 'Change theme', + mode: 'submenu', + search: jest.fn().mockResolvedValue([ + { + id: 'theme-light', + title: 'Light Theme', + description: 'Switch to light theme', + type: 'command' as const, + data: { command: 'theme.set', args: { theme: 'light' } }, + }, + { + id: 'theme-dark', + title: 'Dark Theme', + description: 'Switch to dark theme', + type: 'command' as const, + data: { command: 'theme.set', args: { theme: 'dark' } }, + }, + ]), + register: jest.fn(), + unregister: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + ;(slashCommandRegistry as any).findCommand = jest.fn((name: string) => { + if (name === 'docs') return mockDirectCommand + if (name === 'theme') return mockSubmenuCommand + return null + }) + ;(slashCommandRegistry as any).getAllCommands = jest.fn(() => [ + mockDirectCommand, + mockSubmenuCommand, + ]) + }) + + describe('Direct Mode Commands', () => { + it('should execute immediately when selected', () => { + const mockSetShow = jest.fn() + const mockSetSearchQuery = jest.fn() + + // Simulate command selection + const handler = slashCommandRegistry.findCommand('docs') + expect(handler?.mode).toBe('direct') + + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + mockSetShow(false) + mockSetSearchQuery('') + } + + expect(mockDirectCommand.execute).toHaveBeenCalled() + expect(mockSetShow).toHaveBeenCalledWith(false) + expect(mockSetSearchQuery).toHaveBeenCalledWith('') + }) + + it('should not enter submenu for direct mode commands', () => { + const handler = slashCommandRegistry.findCommand('docs') + expect(handler?.mode).toBe('direct') + expect(handler?.execute).toBeDefined() + }) + + it('should close modal after execution', () => { + const mockModalClose = jest.fn() + + const handler = slashCommandRegistry.findCommand('docs') + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + mockModalClose() + } + + expect(mockModalClose).toHaveBeenCalled() + }) + }) + + describe('Submenu Mode Commands', () => { + it('should show options instead of executing immediately', async () => { + const handler = slashCommandRegistry.findCommand('theme') + expect(handler?.mode).toBe('submenu') + + const results = await handler?.search('', 'en') + expect(results).toHaveLength(2) + expect(results?.[0].title).toBe('Light Theme') + expect(results?.[1].title).toBe('Dark Theme') + }) + + it('should not have execute function for submenu mode', () => { + const handler = slashCommandRegistry.findCommand('theme') + expect(handler?.mode).toBe('submenu') + expect(handler?.execute).toBeUndefined() + }) + + it('should keep modal open for selection', () => { + const mockModalClose = jest.fn() + + const handler = slashCommandRegistry.findCommand('theme') + // For submenu mode, modal should not close immediately + expect(handler?.mode).toBe('submenu') + expect(mockModalClose).not.toHaveBeenCalled() + }) + }) + + describe('Mode Detection and Routing', () => { + it('should correctly identify direct mode commands', () => { + const commands = slashCommandRegistry.getAllCommands() + const directCommands = commands.filter(cmd => cmd.mode === 'direct') + const submenuCommands = commands.filter(cmd => cmd.mode === 'submenu') + + expect(directCommands).toContainEqual(expect.objectContaining({ name: 'docs' })) + expect(submenuCommands).toContainEqual(expect.objectContaining({ name: 'theme' })) + }) + + it('should handle missing mode property gracefully', () => { + const commandWithoutMode: SlashCommandHandler = { + name: 'test', + description: 'Test command', + search: jest.fn(), + register: jest.fn(), + unregister: jest.fn(), + } + + ;(slashCommandRegistry as any).findCommand = jest.fn(() => commandWithoutMode) + + const handler = slashCommandRegistry.findCommand('test') + // Default behavior should be submenu when mode is not specified + expect(handler?.mode).toBeUndefined() + expect(handler?.execute).toBeUndefined() + }) + }) + + describe('Enter Key Handling', () => { + // Helper function to simulate key handler behavior + const createKeyHandler = () => { + return (commandKey: string) => { + if (commandKey.startsWith('/')) { + const commandName = commandKey.substring(1) + const handler = slashCommandRegistry.findCommand(commandName) + if (handler?.mode === 'direct' && handler.execute) { + handler.execute() + return true // Indicates handled + } + } + return false + } + } + + it('should trigger direct execution on Enter for direct mode', () => { + const keyHandler = createKeyHandler() + const handled = keyHandler('/docs') + expect(handled).toBe(true) + expect(mockDirectCommand.execute).toHaveBeenCalled() + }) + + it('should not trigger direct execution for submenu mode', () => { + const keyHandler = createKeyHandler() + const handled = keyHandler('/theme') + expect(handled).toBe(false) + expect(mockSubmenuCommand.search).not.toHaveBeenCalled() + }) + }) + + describe('Command Registration', () => { + it('should register both direct and submenu commands', () => { + mockDirectCommand.register?.({}) + mockSubmenuCommand.register?.({ setTheme: jest.fn() }) + + expect(mockDirectCommand.register).toHaveBeenCalled() + expect(mockSubmenuCommand.register).toHaveBeenCalled() + }) + + it('should handle unregistration for both command types', () => { + // Test unregister for direct command + mockDirectCommand.unregister?.() + expect(mockDirectCommand.unregister).toHaveBeenCalled() + + // Test unregister for submenu command + mockSubmenuCommand.unregister?.() + expect(mockSubmenuCommand.unregister).toHaveBeenCalled() + + // Verify both were called independently + expect(mockDirectCommand.unregister).toHaveBeenCalledTimes(1) + expect(mockSubmenuCommand.unregister).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index f8189b0c8a..6d72e957e3 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -129,7 +129,7 @@ const DatasetDetailLayout: FC = (props) => { params: { datasetId }, } = props const pathname = usePathname() - const hideSideBar = /documents\/create$/.test(pathname) + const hideSideBar = pathname.endsWith('documents/create') const { t } = useTranslation() const { isCurrentWorkspaceDatasetOperator } = useAppContext() diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 0d41691dfd..ccbc73aef0 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -1949,57 +1949,6 @@ ___
- - - - ### Path - - - Knowledge ID - - - Document ID - - - - - - ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' - ``` - - - ```json {{ title: 'Response' }} - { - "id": "file_id", - "name": "file_name", - "size": 1024, - "extension": "txt", - "url": "preview_url", - "download_url": "download_url", - "mime_type": "text/plain", - "created_by": "user_id", - "created_at": 1728734540, - } - ``` - - - - -
- - - - - ### パス - - - ナレッジ ID - - - ドキュメント ID - - - - - - ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' - ``` - - - ```json {{ title: 'Response' }} - { - "id": "file_id", - "name": "file_name", - "size": 1024, - "extension": "txt", - "url": "preview_url", - "download_url": "download_url", - "mime_type": "text/plain", - "created_by": "user_id", - "created_at": 1728734540, - } - ``` - - - - -
-
- diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index b7ea889a46..1971d9ff84 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -1991,57 +1991,6 @@ ___
- - - - ### Path - - - 知识库 ID - - - 文档 ID - - - - - - ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' - ``` - - - ```json {{ title: 'Response' }} - { - "id": "file_id", - "name": "file_name", - "size": 1024, - "extension": "txt", - "url": "preview_url", - "download_url": "download_url", - "mime_type": "text/plain", - "created_by": "user_id", - "created_at": 1728734540, - } - ``` - - - - -
- setEmail(e.target.value)} />
- +
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 2037647b99..dc13d59f2b 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -319,7 +319,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx background={appDetail.icon_background} imageUrl={appDetail.icon_url} /> -
+
{appDetail.name}
{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index f719822bf9..f0904b3fd8 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -45,7 +45,7 @@ const ConfigVision: FC = () => { if (draft.file) { draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 draft.file.image = { - ...(draft.file.image || {}), + ...draft.file.image, enabled: value, } } diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index e6b6c83846..5c87eea3ca 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -50,6 +50,7 @@ export type IGetAutomaticResProps = { onFinished: (res: GenRes) => void flowId?: string nodeId?: string + editorId?: string currentPrompt?: string isBasicMode?: boolean } @@ -76,6 +77,7 @@ const GetAutomaticRes: FC = ({ onClose, flowId, nodeId, + editorId, currentPrompt, isBasicMode, onFinished, @@ -132,7 +134,8 @@ const GetAutomaticRes: FC = ({ }, ] - const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}`}`) + // eslint-disable-next-line sonarjs/no-nested-template-literals, sonarjs/no-nested-conditional + const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}`) const instruction = instructionFromSessionStorage || '' const [ideaOutput, setIdeaOutput] = useState('') @@ -166,7 +169,7 @@ const GetAutomaticRes: FC = ({ return true } const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) - const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}`}` + const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}` const { addVersion, current, currentVersionIndex, setCurrentVersionIndex, versions } = useGenData({ storageKey, }) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 70a45a4bbe..cd73874c2c 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -1,6 +1,6 @@ 'use client' -import { useCallback, useRef, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' @@ -35,14 +35,15 @@ type CreateAppProps = { onSuccess: () => void onClose: () => void onCreateFromTemplate?: () => void + defaultAppMode?: AppMode } -function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) { +function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppProps) { const { t } = useTranslation() const { push } = useRouter() const { notify } = useContext(ToastContext) - const [appMode, setAppMode] = useState('advanced-chat') + const [appMode, setAppMode] = useState(defaultAppMode || 'advanced-chat') const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [name, setName] = useState('') @@ -55,6 +56,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) const isCreatingRef = useRef(false) + useEffect(() => { + if (appMode === 'chat' || appMode === 'agent-chat' || appMode === 'completion') + setIsAppTypeExpanded(true) + }, [appMode]) + const onCreate = useCallback(async () => { if (!appMode) { notify({ type: 'error', message: t('app.newApp.appTypeRequired') }) @@ -264,7 +270,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) type CreateAppDialogProps = CreateAppProps & { show: boolean } -const CreateAppModal = ({ show, onClose, onSuccess, onCreateFromTemplate }: CreateAppDialogProps) => { +const CreateAppModal = ({ show, onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppDialogProps) => { return ( - + ) } diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index aa85fb1313..4ee9a6d6d5 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -211,14 +211,14 @@ const List = () => { {(data && data[0].total > 0) ?
{isCurrentWorkspaceEditor - && } + && } {data.map(({ data: apps }) => apps.map(app => ( )))}
:
{isCurrentWorkspaceEditor - && } + && }
} diff --git a/web/app/components/apps/new-app-card.tsx b/web/app/components/apps/new-app-card.tsx index 451d2ae326..6ceeb47982 100644 --- a/web/app/components/apps/new-app-card.tsx +++ b/web/app/components/apps/new-app-card.tsx @@ -26,12 +26,14 @@ export type CreateAppCardProps = { className?: string onSuccess?: () => void ref: React.RefObject + selectedAppType?: string } const CreateAppCard = ({ ref, className, onSuccess, + selectedAppType, }: CreateAppCardProps) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() @@ -86,6 +88,7 @@ const CreateAppCard = ({ setShowNewAppTemplateDialog(true) setShowNewAppModal(false) }} + defaultAppMode={selectedAppType !== 'all' ? selectedAppType as any : undefined} /> )} {showNewAppTemplateDialog && ( diff --git a/web/app/components/base/agent-log-modal/tool-call.tsx b/web/app/components/base/agent-log-modal/tool-call.tsx index 499a70367c..433a20fd5d 100644 --- a/web/app/components/base/agent-log-modal/tool-call.tsx +++ b/web/app/components/base/agent-log-modal/tool-call.tsx @@ -33,7 +33,7 @@ const ToolCallItem: FC = ({ toolCall, isLLM = false, isFinal, tokens, obs if (time < 1) return `${(time * 1000).toFixed(3)} ms` if (time > 60) - return `${Number.parseInt(Math.round(time / 60).toString())} m ${(time % 60).toFixed(3)} s` + return `${Math.floor(time / 60)} m ${(time % 60).toFixed(3)} s` return `${time.toFixed(3)} s` } diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 17373cec9d..665e7e8bc3 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -682,7 +682,7 @@ export const useChat = ( updateChatTreeNode(targetAnswerId, { content: chatList[index].content, annotation: { - ...(chatList[index].annotation || {}), + ...chatList[index].annotation, id: '', } as Annotation, }) diff --git a/web/app/components/base/date-and-time-picker/date-picker/index.tsx b/web/app/components/base/date-and-time-picker/date-picker/index.tsx index f99b8257c1..f6b7973cb0 100644 --- a/web/app/components/base/date-and-time-picker/date-picker/index.tsx +++ b/web/app/components/base/date-and-time-picker/date-picker/index.tsx @@ -42,7 +42,14 @@ const DatePicker = ({ const [view, setView] = useState(ViewType.date) const containerRef = useRef(null) const isInitial = useRef(true) - const inputValue = useRef(value ? value.tz(timezone) : undefined).current + + // Normalize the value to ensure that all subsequent uses are Day.js objects. + const normalizedValue = useMemo(() => { + if (!value) return undefined + return dayjs.isDayjs(value) ? value.tz(timezone) : dayjs(value).tz(timezone) + }, [value, timezone]) + + const inputValue = useRef(normalizedValue).current const defaultValue = useRef(getDateWithTimezone({ timezone })).current const [currentDate, setCurrentDate] = useState(inputValue || defaultValue) @@ -68,8 +75,8 @@ const DatePicker = ({ return } clearMonthMapCache() - if (value) { - const newValue = getDateWithTimezone({ date: value, timezone }) + if (normalizedValue) { + const newValue = getDateWithTimezone({ date: normalizedValue, timezone }) setCurrentDate(newValue) setSelectedDate(newValue) onChange(newValue) @@ -88,9 +95,9 @@ const DatePicker = ({ } setView(ViewType.date) setIsOpen(true) - if (value) { - setCurrentDate(value) - setSelectedDate(value) + if (normalizedValue) { + setCurrentDate(normalizedValue) + setSelectedDate(normalizedValue) } } @@ -192,7 +199,7 @@ const DatePicker = ({ } const timeFormat = needTimePicker ? t('time.dateFormats.displayWithTime') : t('time.dateFormats.display') - const displayValue = value?.format(timeFormat) || '' + const displayValue = normalizedValue?.format(timeFormat) || '' const displayTime = selectedDate?.format('hh:mm A') || '--:-- --' const placeholderDate = isOpen && selectedDate ? selectedDate.format(timeFormat) : (placeholder || t('time.defaultPlaceholder')) @@ -204,7 +211,7 @@ const DatePicker = ({ > {renderTrigger ? (renderTrigger({ - value, + value: normalizedValue, selectedDate, isOpen, handleClear, diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index 53db991e71..ec8681f37c 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -83,9 +83,7 @@ const OpeningSettingModal = ({ }, [handleSave, hideConfirmAddVar]) const autoAddVar = useCallback(() => { - onAutoAddPromptVariable?.([ - ...notIncludeKeys.map(key => getNewVar(key, 'string')), - ]) + onAutoAddPromptVariable?.(notIncludeKeys.map(key => getNewVar(key, 'string'))) hideConfirmAddVar() handleSave(true) }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable]) diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index 7a95561754..81d84a85d3 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -24,7 +24,7 @@ const GA: FC = ({ if (IS_CE_EDITION) return null - const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') : '' + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' return ( <> @@ -32,7 +32,7 @@ const GA: FC = ({ strategy="beforeInteractive" async src={`https://www.googletagmanager.com/gtag/js?id=${gaIdMaps[gaType]}`} - nonce={nonce!} + nonce={nonce ?? undefined} > {/* Cookie banner */} diff --git a/web/app/components/base/markdown-blocks/link.tsx b/web/app/components/base/markdown-blocks/link.tsx index 0274ee0141..9bf13040a7 100644 --- a/web/app/components/base/markdown-blocks/link.tsx +++ b/web/app/components/base/markdown-blocks/link.tsx @@ -17,7 +17,7 @@ const Link = ({ node, children, ...props }: any) => { } else { const href = props.href || node.properties?.href - if (href && /^#[a-zA-Z0-9_\-]+$/.test(href.toString())) { + if (href && /^#[a-zA-Z0-9_-]+$/.test(href.toString())) { const handleClick = (e: React.MouseEvent) => { e.preventDefault() // scroll to target element if exists within the answer container diff --git a/web/app/components/base/notion-page-selector/page-selector/index.tsx b/web/app/components/base/notion-page-selector/page-selector/index.tsx index 498955994c..a1ce3116db 100644 --- a/web/app/components/base/notion-page-selector/page-selector/index.tsx +++ b/web/app/components/base/notion-page-selector/page-selector/index.tsx @@ -229,7 +229,7 @@ const PageSelector = ({ if (current.expand) { current.expand = false - newDataList = [...dataList.filter(item => !descendantsIds.includes(item.page_id))] + newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id)) } else { current.expand = true @@ -246,7 +246,7 @@ const PageSelector = ({ setDataList(newDataList) } - const copyValue = new Set([...value]) + const copyValue = new Set(value) const handleCheck = (index: number) => { const current = currentDataList[index] const pageId = current.page_id @@ -269,7 +269,7 @@ const PageSelector = ({ copyValue.add(pageId) } - onSelect(new Set([...copyValue])) + onSelect(new Set(copyValue)) } const handlePreview = (index: number) => { diff --git a/web/app/components/base/param-item/score-threshold-item.tsx b/web/app/components/base/param-item/score-threshold-item.tsx index b5557c80cf..3790a2a074 100644 --- a/web/app/components/base/param-item/score-threshold-item.tsx +++ b/web/app/components/base/param-item/score-threshold-item.tsx @@ -20,7 +20,6 @@ const VALUE_LIMIT = { max: 1, } -const key = 'score_threshold' const ScoreThresholdItem: FC = ({ className, value, @@ -39,9 +38,9 @@ const ScoreThresholdItem: FC = ({ return ( = ({ className, value, @@ -41,9 +40,9 @@ const TopKItem: FC = ({ return ( ), ) diff --git a/web/app/components/base/zendesk/index.tsx b/web/app/components/base/zendesk/index.tsx new file mode 100644 index 0000000000..b3d67eb390 --- /dev/null +++ b/web/app/components/base/zendesk/index.tsx @@ -0,0 +1,21 @@ +import { memo } from 'react' +import { type UnsafeUnwrappedHeaders, headers } from 'next/headers' +import Script from 'next/script' +import { IS_CE_EDITION, ZENDESK_WIDGET_KEY } from '@/config' + +const Zendesk = () => { + if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY) + return null + + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' + + return ( +