diff --git a/.cursorrules b/.cursorrules
new file mode 100644
index 0000000000..cdfb8b17a3
--- /dev/null
+++ b/.cursorrules
@@ -0,0 +1,6 @@
+# Cursor Rules for Dify Project
+
+## Automated Test Generation
+
+- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
+- When proposing or saving tests, re-read that document and follow every requirement.
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 0000000000..d6f326d4dc
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,234 @@
+# CODEOWNERS
+# This file defines code ownership for the Dify project.
+# Each line is a file pattern followed by one or more owners.
+# Owners can be @username, @org/team-name, or email addresses.
+# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
+
+* @crazywoola @laipz8200 @Yeuoly
+
+# Backend (default owner, more specific rules below will override)
+api/ @QuantumGhost
+
+# Backend - MCP
+api/core/mcp/ @Nov1c444
+api/core/entities/mcp_provider.py @Nov1c444
+api/services/tools/mcp_tools_manage_service.py @Nov1c444
+api/controllers/mcp/ @Nov1c444
+api/controllers/console/app/mcp_server.py @Nov1c444
+api/tests/**/*mcp* @Nov1c444
+
+# Backend - Workflow - Engine (Core graph execution engine)
+api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
+api/core/workflow/runtime/ @laipz8200 @QuantumGhost
+api/core/workflow/graph/ @laipz8200 @QuantumGhost
+api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
+api/core/workflow/node_events/ @laipz8200 @QuantumGhost
+api/core/model_runtime/ @laipz8200 @QuantumGhost
+
+# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
+api/core/workflow/nodes/agent/ @Nov1c444
+api/core/workflow/nodes/iteration/ @Nov1c444
+api/core/workflow/nodes/loop/ @Nov1c444
+api/core/workflow/nodes/llm/ @Nov1c444
+
+# Backend - RAG (Retrieval Augmented Generation)
+api/core/rag/ @JohnJyong
+api/services/rag_pipeline/ @JohnJyong
+api/services/dataset_service.py @JohnJyong
+api/services/knowledge_service.py @JohnJyong
+api/services/external_knowledge_service.py @JohnJyong
+api/services/hit_testing_service.py @JohnJyong
+api/services/metadata_service.py @JohnJyong
+api/services/vector_service.py @JohnJyong
+api/services/entities/knowledge_entities/ @JohnJyong
+api/services/entities/external_knowledge_entities/ @JohnJyong
+api/controllers/console/datasets/ @JohnJyong
+api/controllers/service_api/dataset/ @JohnJyong
+api/models/dataset.py @JohnJyong
+api/tasks/rag_pipeline/ @JohnJyong
+api/tasks/add_document_to_index_task.py @JohnJyong
+api/tasks/batch_clean_document_task.py @JohnJyong
+api/tasks/clean_document_task.py @JohnJyong
+api/tasks/clean_notion_document_task.py @JohnJyong
+api/tasks/document_indexing_task.py @JohnJyong
+api/tasks/document_indexing_sync_task.py @JohnJyong
+api/tasks/document_indexing_update_task.py @JohnJyong
+api/tasks/duplicate_document_indexing_task.py @JohnJyong
+api/tasks/recover_document_indexing_task.py @JohnJyong
+api/tasks/remove_document_from_index_task.py @JohnJyong
+api/tasks/retry_document_indexing_task.py @JohnJyong
+api/tasks/sync_website_document_indexing_task.py @JohnJyong
+api/tasks/batch_create_segment_to_index_task.py @JohnJyong
+api/tasks/create_segment_to_index_task.py @JohnJyong
+api/tasks/delete_segment_from_index_task.py @JohnJyong
+api/tasks/disable_segment_from_index_task.py @JohnJyong
+api/tasks/disable_segments_from_index_task.py @JohnJyong
+api/tasks/enable_segment_to_index_task.py @JohnJyong
+api/tasks/enable_segments_to_index_task.py @JohnJyong
+api/tasks/clean_dataset_task.py @JohnJyong
+api/tasks/deal_dataset_index_update_task.py @JohnJyong
+api/tasks/deal_dataset_vector_index_task.py @JohnJyong
+
+# Backend - Plugins
+api/core/plugin/ @Mairuis @Yeuoly @Stream29
+api/services/plugin/ @Mairuis @Yeuoly @Stream29
+api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
+api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
+api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
+
+# Backend - Trigger/Schedule/Webhook
+api/controllers/trigger/ @Mairuis @Yeuoly
+api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
+api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
+api/core/trigger/ @Mairuis @Yeuoly
+api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
+api/services/trigger/ @Mairuis @Yeuoly
+api/models/trigger.py @Mairuis @Yeuoly
+api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
+api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
+api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
+api/libs/schedule_utils.py @Mairuis @Yeuoly
+api/services/workflow/scheduler.py @Mairuis @Yeuoly
+api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
+api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
+api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
+api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
+api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
+api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
+api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
+api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
+api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
+api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
+
+# Backend - Async Workflow
+api/services/async_workflow_service.py @Mairuis @Yeuoly
+api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
+
+# Backend - Billing
+api/services/billing_service.py @hj24 @zyssyz123
+api/controllers/console/billing/ @hj24 @zyssyz123
+
+# Backend - Enterprise
+api/configs/enterprise/ @GarfieldDai @GareArc
+api/services/enterprise/ @GarfieldDai @GareArc
+api/services/feature_service.py @GarfieldDai @GareArc
+api/controllers/console/feature.py @GarfieldDai @GareArc
+api/controllers/web/feature.py @GarfieldDai @GareArc
+
+# Backend - Database Migrations
+api/migrations/ @snakevash @laipz8200
+
+# Frontend
+web/ @iamjoel
+
+# Frontend - App - Orchestration
+web/app/components/workflow/ @iamjoel @zxhlyh
+web/app/components/workflow-app/ @iamjoel @zxhlyh
+web/app/components/app/configuration/ @iamjoel @zxhlyh
+web/app/components/app/app-publisher/ @iamjoel @zxhlyh
+
+# Frontend - WebApp - Chat
+web/app/components/base/chat/ @iamjoel @zxhlyh
+
+# Frontend - WebApp - Completion
+web/app/components/share/text-generation/ @iamjoel @zxhlyh
+
+# Frontend - App - List and Creation
+web/app/components/apps/ @JzoNgKVO @iamjoel
+web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
+web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
+web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
+
+# Frontend - App - API Documentation
+web/app/components/develop/ @JzoNgKVO @iamjoel
+
+# Frontend - App - Logs and Annotations
+web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
+web/app/components/app/log/ @JzoNgKVO @iamjoel
+web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
+web/app/components/app/annotation/ @JzoNgKVO @iamjoel
+
+# Frontend - App - Monitoring
+web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
+web/app/components/app/overview/ @JzoNgKVO @iamjoel
+
+# Frontend - App - Settings
+web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
+
+# Frontend - RAG - Hit Testing
+web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
+
+# Frontend - RAG - List and Creation
+web/app/components/datasets/list/ @iamjoel @WTW0313
+web/app/components/datasets/create/ @iamjoel @WTW0313
+web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
+web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
+
+# Frontend - RAG - Orchestration (general rule first, specific rules below override)
+web/app/components/rag-pipeline/ @iamjoel @WTW0313
+web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
+web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
+
+# Frontend - RAG - Documents List
+web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
+web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
+
+# Frontend - RAG - Segments List
+web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
+
+# Frontend - RAG - Settings
+web/app/components/datasets/settings/ @iamjoel @WTW0313
+
+# Frontend - Ecosystem - Plugins
+web/app/components/plugins/ @iamjoel @zhsama
+
+# Frontend - Ecosystem - Tools
+web/app/components/tools/ @iamjoel @Yessenia-d
+
+# Frontend - Ecosystem - MarketPlace
+web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
+
+# Frontend - Login and Registration
+web/app/signin/ @douxc @iamjoel
+web/app/signup/ @douxc @iamjoel
+web/app/reset-password/ @douxc @iamjoel
+web/app/install/ @douxc @iamjoel
+web/app/init/ @douxc @iamjoel
+web/app/forgot-password/ @douxc @iamjoel
+web/app/account/ @douxc @iamjoel
+
+# Frontend - Service Authentication
+web/service/base.ts @douxc @iamjoel
+
+# Frontend - WebApp Authentication and Access Control
+web/app/(shareLayout)/components/ @douxc @iamjoel
+web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
+web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
+web/app/components/app/app-access-control/ @douxc @iamjoel
+
+# Frontend - Explore Page
+web/app/components/explore/ @CodingOnStar @iamjoel
+
+# Frontend - Personal Settings
+web/app/components/header/account-setting/ @CodingOnStar @iamjoel
+web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
+
+# Frontend - Analytics
+web/app/components/base/ga/ @CodingOnStar @iamjoel
+
+# Frontend - Base Components
+web/app/components/base/ @iamjoel @zxhlyh
+
+# Frontend - Utils and Hooks
+web/utils/classnames.ts @iamjoel @zxhlyh
+web/utils/time.ts @iamjoel @zxhlyh
+web/utils/format.ts @iamjoel @zxhlyh
+web/utils/clipboard.ts @iamjoel @zxhlyh
+web/hooks/use-document-title.ts @iamjoel @zxhlyh
+
+# Frontend - Billing and Education
+web/app/components/billing/ @iamjoel @zxhlyh
+web/app/education-apply/ @iamjoel @zxhlyh
+
+# Frontend - Workspace
+web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
diff --git a/.github/ISSUE_TEMPLATE/refactor.yml b/.github/ISSUE_TEMPLATE/refactor.yml
index cf74dcc546..dbe8cbb602 100644
--- a/.github/ISSUE_TEMPLATE/refactor.yml
+++ b/.github/ISSUE_TEMPLATE/refactor.yml
@@ -1,8 +1,6 @@
-name: "✨ Refactor"
-description: Refactor existing code for improved readability and maintainability.
-title: "[Chore/Refactor] "
-labels:
- - refactor
+name: "✨ Refactor or Chore"
+description: Refactor existing code or perform maintenance chores to improve readability and reliability.
+title: "[Refactor/Chore] "
body:
- type: checkboxes
attributes:
@@ -11,7 +9,7 @@ body:
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+ - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
@@ -25,14 +23,14 @@ body:
id: description
attributes:
label: Description
- placeholder: "Describe the refactor you are proposing."
+ placeholder: "Describe the refactor or chore you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
- placeholder: "Explain why this refactor is necessary."
+ placeholder: "Explain why this refactor or chore is necessary."
validations:
required: false
- type: textarea
diff --git a/.github/ISSUE_TEMPLATE/tracker.yml b/.github/ISSUE_TEMPLATE/tracker.yml
deleted file mode 100644
index 35fedefc75..0000000000
--- a/.github/ISSUE_TEMPLATE/tracker.yml
+++ /dev/null
@@ -1,13 +0,0 @@
-name: "👾 Tracker"
-description: For inner usages, please do not use this template.
-title: "[Tracker] "
-labels:
- - tracker
-body:
- - type: textarea
- id: content
- attributes:
- label: Blockers
- placeholder: "- [ ] ..."
- validations:
- required: true
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
new file mode 100644
index 0000000000..53afcbda1e
--- /dev/null
+++ b/.github/copilot-instructions.md
@@ -0,0 +1,12 @@
+# Copilot Instructions
+
+GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
+
+Key reminders:
+
+- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, mock cleanup).
+- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
+- Target >95% line and branch coverage and 100% function/statement coverage.
+- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
+
+Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml
new file mode 100644
index 0000000000..b15c26a096
--- /dev/null
+++ b/.github/workflows/semantic-pull-request.yml
@@ -0,0 +1,21 @@
+name: Semantic Pull Request
+
+on:
+ pull_request:
+ types:
+ - opened
+ - edited
+ - reopened
+ - synchronize
+
+jobs:
+ lint:
+ name: Validate PR title
+ permissions:
+ pull-requests: read
+ runs-on: ubuntu-latest
+ steps:
+ - name: Check title
+ uses: amannn/action-semantic-pull-request@v6.1.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index e652657705..5a8a34be79 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -106,7 +106,7 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
- run: pnpm run type-check
+ run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template
diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml
index 836c3e0b02..fe8e2ebc2b 100644
--- a/.github/workflows/translate-i18n-base-on-english.yml
+++ b/.github/workflows/translate-i18n-base-on-english.yml
@@ -20,22 +20,22 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
- fetch-depth: 2
+ fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
run: |
- recent_commit_sha=$(git rev-parse HEAD)
- second_recent_commit_sha=$(git rev-parse HEAD~1)
- changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
+ git fetch origin "${{ github.event.before }}" || true
+ git fetch origin "${{ github.sha }}" || true
+ changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .ts)
- file_args="$file_args --file=$filename"
+ file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
@@ -77,12 +77,15 @@ jobs:
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
- commit-message: Update i18n files and type definitions based on en-US changes
- title: 'chore: translate i18n files and update type definitions'
+ commit-message: 'chore(i18n): update translations based on en-US changes'
+ title: 'chore(i18n): translate i18n files and update type definitions'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
-
+
+ **Triggered by:** ${{ github.sha }}
+
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
- branch: chore/automated-i18n-updates
+ branch: chore/automated-i18n-updates-${{ github.sha }}
+ delete-branch: true
diff --git a/.nvmrc b/.nvmrc
new file mode 100644
index 0000000000..7af24b7ddb
--- /dev/null
+++ b/.nvmrc
@@ -0,0 +1 @@
+22.11.0
diff --git a/.windsurf/rules/testing.md b/.windsurf/rules/testing.md
new file mode 100644
index 0000000000..64fec20cb8
--- /dev/null
+++ b/.windsurf/rules/testing.md
@@ -0,0 +1,5 @@
+# Windsurf Testing Rules
+
+- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
+- Honor every requirement in that document when generating or accepting tests.
+- When proposing or saving tests, re-read that document and follow every requirement.
diff --git a/AGENTS.md b/AGENTS.md
index 2ef7931efc..782861ad36 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -24,8 +24,8 @@ The codebase is split into:
```bash
cd web
-pnpm lint
pnpm lint:fix
+pnpm type-check:tsgo
pnpm test
```
@@ -39,7 +39,7 @@ pnpm test
## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
-- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
+- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
## General Practices
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index fdc414b047..20a7d6c6f6 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,6 +77,8 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
+**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
+
#### Backend
For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.
diff --git a/README.md b/README.md
index e5cc05fbc0..b71764a214 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,12 @@
+
+
+
+
+
+
@@ -133,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
+#### Customizing Suggested Questions
+
+You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
+
+```bash
+# In your .env file
+SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
+SUGGESTED_QUESTIONS_MAX_TOKENS=512
+SUGGESTED_QUESTIONS_TEMPERATURE=0.3
+```
+
+See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
+
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
diff --git a/api/.env.example b/api/.env.example
index ba512a668d..516a119d98 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -176,6 +176,7 @@ WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
+WEAVIATE_TOKENIZATION=word
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
@@ -539,6 +540,7 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200
+APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
@@ -631,8 +633,30 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
+# Suggested Questions After Answer Configuration
+# These environment variables allow customization of the suggested questions feature
+#
+# Custom prompt for generating suggested questions (optional)
+# If not set, uses the default prompt that generates 3 questions under 20 characters each
+# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
+# SUGGESTED_QUESTIONS_PROMPT=
+
+# Maximum number of tokens for suggested questions generation (default: 256)
+# Adjust this value for longer questions or more questions
+# SUGGESTED_QUESTIONS_MAX_TOKENS=256
+
+# Temperature for suggested questions generation (default: 0.0)
+# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
+# SUGGESTED_QUESTIONS_TEMPERATURE=0
+
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
+
+# Multimodal knowledgebase limit
+SINGLE_CHUNK_ATTACHMENT_LIMIT=10
+ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
+ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
+IMAGE_FILE_BATCH_LIMIT=10
diff --git a/api/.importlinter b/api/.importlinter
index 98fe5f50bb..24ece72b30 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -16,6 +16,7 @@ layers =
graph
nodes
node_events
+ runtime
entities
containers =
core.workflow
diff --git a/api/.ruff.toml b/api/.ruff.toml
index 5a29e1d8fa..7206f7fa0f 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -36,17 +36,20 @@ select = [
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
+ "G001", # don't use str format to logging messages
+ "G003", # don't use + in logging messages
+ "G004", # don't use f-strings to format logging messages
+    "UP042", # use StrEnum
+ "S110", # disallow the try-except-pass pattern.
+
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
- "S311", # suspicious-non-cryptographic-random-usage
- "G001", # don't use str format to logging messages
- "G003", # don't use + in logging messages
- "G004", # don't use f-strings to format logging messages
- "UP042", # use StrEnum
+    "S311", # suspicious-non-cryptographic-random-usage
+
]
ignore = [
@@ -91,18 +94,16 @@ ignore = [
"configs/*" = [
"N802", # invalid-function-name
]
-"core/model_runtime/callbacks/base_callback.py" = [
- "T201",
-]
-"core/workflow/callbacks/workflow_logging_callback.py" = [
- "T201",
-]
+"core/model_runtime/callbacks/base_callback.py" = ["T201"]
+"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
- "T201", # allow print in tests
+    "T201", # allow print in tests
+    "S110", # allow ignoring exceptions in test code (currently)
+
]
[lint.pyflakes]
diff --git a/api/Dockerfile b/api/Dockerfile
index ed61923a40..02df91bfc1 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -48,6 +48,12 @@ ENV PYTHONIOENCODING=utf-8
WORKDIR /app/api
+# Create non-root user
+ARG dify_uid=1001
+RUN groupadd -r -g ${dify_uid} dify && \
+ useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
+ chown -R dify:dify /app
+
RUN \
apt-get update \
# Install dependencies
@@ -57,7 +63,7 @@ RUN \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
- expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
+ expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
# install fonts to support the use of tools like pypdfium2
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
@@ -69,24 +75,29 @@ RUN \
# Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv
-COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
+COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
-RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
+RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
+ && chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
-RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
+RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
+ && chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
# Copy source code
-COPY . /app/api/
+COPY --chown=dify:dify . /app/api/
+
+# Prepare entrypoint script
+COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
-# Copy entrypoint
-COPY docker/entrypoint.sh /entrypoint.sh
-RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
+ENV NLTK_DATA=/usr/local/share/nltk_data
+
+USER dify
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
diff --git a/api/app_factory.py b/api/app_factory.py
index 933cf294d1..3a3ee03cff 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -1,6 +1,8 @@
import logging
import time
+from opentelemetry.trace import get_current_span
+
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
@@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
+ # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
+ @dify_app.after_request
+ def add_trace_id_header(response):
+ try:
+ span = get_current_span()
+ ctx = span.get_span_context() if span else None
+ if ctx and ctx.is_valid:
+ trace_id_hex = format(ctx.trace_id, "032x")
+ # Avoid duplicates if some middleware added it
+ if "X-Trace-Id" not in response.headers:
+ response.headers["X-Trace-Id"] = trace_id_hex
+ except Exception:
+ # Never break the response due to tracing header injection
+ logger.warning("Failed to add trace ID to response header", exc_info=True)
+ return response
+
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
+ _ = add_trace_id_header
return dify_app
@@ -51,6 +70,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
+ ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
@@ -75,6 +95,7 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
ext_import_modules,
ext_orjson,
+ ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
diff --git a/api/commands.py b/api/commands.py
index e15c996a34..a8d89ac200 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1139,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
+ return
all_files_on_storage = []
for storage_path in storage_paths:
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 7cce3847b4..a5916241df 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -73,6 +73,10 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
+ APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
+ description="Default number of concurrent active requests per app (0 for unlimited)",
+ default=0,
+ )
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
@@ -356,6 +360,26 @@ class FileUploadConfig(BaseSettings):
default=10,
)
+ IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
+        description="Maximum number of files allowed in an image batch upload operation",
+ default=10,
+ )
+
+ SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
+ description="Maximum number of files allowed in a single chunk attachment",
+ default=10,
+ )
+
+ ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+ description="Maximum allowed image file size for attachments in megabytes",
+ default=2,
+ )
+
+ ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
+ description="Timeout for downloading image attachments in seconds",
+ default=60,
+ )
+
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
@@ -549,7 +573,10 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field(
description="Format string for log messages",
- default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
+ default=(
+ "%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
+ "[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
+ ),
)
LOG_DATEFORMAT: str | None = Field(
diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py
index aa81c870f6..6f4fccaa7f 100644
--- a/api/configs/middleware/vdb/weaviate_config.py
+++ b/api/configs/middleware/vdb/weaviate_config.py
@@ -31,3 +31,8 @@ class WeaviateConfig(BaseSettings):
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,
)
+
+ WEAVIATE_TOKENIZATION: str | None = Field(
+ description="Tokenization for Weaviate (default is word)",
+ default="word",
+ )
diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py
new file mode 100644
index 0000000000..e0896a8dc2
--- /dev/null
+++ b/api/controllers/common/schema.py
@@ -0,0 +1,26 @@
+"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
+
+from flask_restx import Namespace
+from pydantic import BaseModel
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
+ """Register a single BaseModel with a namespace for Swagger documentation."""
+
+ namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
+ """Register multiple BaseModels with a namespace."""
+
+ for model in models:
+ register_schema_model(namespace, model)
+
+
+__all__ = [
+ "DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
+ "register_schema_model",
+ "register_schema_models",
+]
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 2c4d8709eb..7aa1e6dbd8 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -3,7 +3,8 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
@@ -12,12 +13,36 @@ P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class InsertExploreAppPayload(BaseModel):
+ app_id: str = Field(...)
+ desc: str | None = None
+ copyright: str | None = None
+ privacy_policy: str | None = None
+ custom_disclaimer: str | None = None
+ language: str = Field(...)
+ category: str = Field(...)
+ position: int = Field(...)
+
+ @field_validator("language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+
+console_ns.schema_model(
+ InsertExploreAppPayload.__name__,
+ InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
def admin_required(view: Callable[P, R]):
@wraps(view)
@@ -38,61 +63,36 @@ def admin_required(view: Callable[P, R]):
@console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource):
- @api.doc("insert_explore_app")
- @api.doc(description="Insert or update an app in the explore list")
- @api.expect(
- api.model(
- "InsertExploreAppRequest",
- {
- "app_id": fields.String(required=True, description="Application ID"),
- "desc": fields.String(description="App description"),
- "copyright": fields.String(description="Copyright information"),
- "privacy_policy": fields.String(description="Privacy policy"),
- "custom_disclaimer": fields.String(description="Custom disclaimer"),
- "language": fields.String(required=True, description="Language code"),
- "category": fields.String(required=True, description="App category"),
- "position": fields.Integer(required=True, description="Display position"),
- },
- )
- )
- @api.response(200, "App updated successfully")
- @api.response(201, "App inserted successfully")
- @api.response(404, "App not found")
+ @console_ns.doc("insert_explore_app")
+ @console_ns.doc(description="Insert or update an app in the explore list")
+ @console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
+ @console_ns.response(200, "App updated successfully")
+ @console_ns.response(201, "App inserted successfully")
+ @console_ns.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("app_id", type=str, required=True, nullable=False, location="json")
- .add_argument("desc", type=str, location="json")
- .add_argument("copyright", type=str, location="json")
- .add_argument("privacy_policy", type=str, location="json")
- .add_argument("custom_disclaimer", type=str, location="json")
- .add_argument("language", type=supported_language, required=True, nullable=False, location="json")
- .add_argument("category", type=str, required=True, nullable=False, location="json")
- .add_argument("position", type=int, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = InsertExploreAppPayload.model_validate(console_ns.payload)
- app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
+ app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app:
- raise NotFound(f"App '{args['app_id']}' is not found")
+ raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site
if not site:
- desc = args["desc"] or ""
- copy_right = args["copyright"] or ""
- privacy_policy = args["privacy_policy"] or ""
- custom_disclaimer = args["custom_disclaimer"] or ""
+ desc = payload.desc or ""
+ copy_right = payload.copyright or ""
+ privacy_policy = payload.privacy_policy or ""
+ custom_disclaimer = payload.custom_disclaimer or ""
else:
- desc = site.description or args["desc"] or ""
- copy_right = site.copyright or args["copyright"] or ""
- privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
- custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
+ desc = site.description or payload.desc or ""
+ copy_right = site.copyright or payload.copyright or ""
+ privacy_policy = site.privacy_policy or payload.privacy_policy or ""
+ custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
recommended_app = session.execute(
- select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
+ select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
if not recommended_app:
@@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
- language=args["language"],
- category=args["category"],
- position=args["position"],
+ language=payload.language,
+ category=payload.category,
+ position=payload.position,
)
db.session.add(recommended_app)
@@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
- recommended_app.language = args["language"]
- recommended_app.category = args["category"]
- recommended_app.position = args["position"]
+ recommended_app.language = payload.language
+ recommended_app.category = payload.category
+ recommended_app.position = payload.position
app.is_public = True
@@ -131,10 +131,10 @@ class InsertExploreAppListApi(Resource):
@console_ns.route("/admin/insert-explore-apps/")
class InsertExploreAppApi(Resource):
- @api.doc("delete_explore_app")
- @api.doc(description="Remove an app from the explore list")
- @api.doc(params={"app_id": "Application ID to remove"})
- @api.response(204, "App removed successfully")
+ @console_ns.doc("delete_explore_app")
+ @console_ns.doc(description="Remove an app from the explore list")
+ @console_ns.doc(params={"app_id": "Application ID to remove"})
+ @console_ns.response(204, "App removed successfully")
@only_edition_cloud
@admin_required
def delete(self, app_id):
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index bd5862cbd0..9b0d4b1a78 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -11,7 +11,7 @@ from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
-from . import api, console_ns
+from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
@@ -24,6 +24,12 @@ api_key_fields = {
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
+api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
+
+api_key_list_model = console_ns.model(
+ "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
+)
+
def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App:
@@ -52,7 +58,7 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
- @marshal_with(api_key_list)
+ @marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
@@ -66,7 +72,7 @@ class BaseApiKeyListResource(Resource):
).all()
return {"items": keys}
- @marshal_with(api_key_fields)
+ @marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
@@ -133,20 +139,20 @@ class BaseApiKeyResource(Resource):
@console_ns.route("/apps//api-keys")
class AppApiKeyListResource(BaseApiKeyListResource):
- @api.doc("get_app_api_keys")
- @api.doc(description="Get all API keys for an app")
- @api.doc(params={"resource_id": "App ID"})
- @api.response(200, "Success", api_key_list)
- def get(self, resource_id):
+ @console_ns.doc("get_app_api_keys")
+ @console_ns.doc(description="Get all API keys for an app")
+ @console_ns.doc(params={"resource_id": "App ID"})
+ @console_ns.response(200, "Success", api_key_list_model)
+ def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
- @api.doc("create_app_api_key")
- @api.doc(description="Create a new API key for an app")
- @api.doc(params={"resource_id": "App ID"})
- @api.response(201, "API key created successfully", api_key_fields)
- @api.response(400, "Maximum keys exceeded")
- def post(self, resource_id):
+ @console_ns.doc("create_app_api_key")
+ @console_ns.doc(description="Create a new API key for an app")
+ @console_ns.doc(params={"resource_id": "App ID"})
+ @console_ns.response(201, "API key created successfully", api_key_item_model)
+ @console_ns.response(400, "Maximum keys exceeded")
+ def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
return super().post(resource_id)
@@ -158,10 +164,10 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/apps//api-keys/")
class AppApiKeyResource(BaseApiKeyResource):
- @api.doc("delete_app_api_key")
- @api.doc(description="Delete an API key for an app")
- @api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
- @api.response(204, "API key deleted successfully")
+ @console_ns.doc("delete_app_api_key")
+ @console_ns.doc(description="Delete an API key for an app")
+ @console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
+ @console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
@@ -173,20 +179,20 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.route("/datasets//api-keys")
class DatasetApiKeyListResource(BaseApiKeyListResource):
- @api.doc("get_dataset_api_keys")
- @api.doc(description="Get all API keys for a dataset")
- @api.doc(params={"resource_id": "Dataset ID"})
- @api.response(200, "Success", api_key_list)
- def get(self, resource_id):
+ @console_ns.doc("get_dataset_api_keys")
+ @console_ns.doc(description="Get all API keys for a dataset")
+ @console_ns.doc(params={"resource_id": "Dataset ID"})
+ @console_ns.response(200, "Success", api_key_list_model)
+ def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
- @api.doc("create_dataset_api_key")
- @api.doc(description="Create a new API key for a dataset")
- @api.doc(params={"resource_id": "Dataset ID"})
- @api.response(201, "API key created successfully", api_key_fields)
- @api.response(400, "Maximum keys exceeded")
- def post(self, resource_id):
+ @console_ns.doc("create_dataset_api_key")
+ @console_ns.doc(description="Create a new API key for a dataset")
+ @console_ns.doc(params={"resource_id": "Dataset ID"})
+ @console_ns.response(201, "API key created successfully", api_key_item_model)
+ @console_ns.response(400, "Maximum keys exceeded")
+ def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""
return super().post(resource_id)
@@ -198,10 +204,10 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/datasets//api-keys/")
class DatasetApiKeyResource(BaseApiKeyResource):
- @api.doc("delete_dataset_api_key")
- @api.doc(description="Delete an API key for a dataset")
- @api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
- @api.response(204, "API key deleted successfully")
+ @console_ns.doc("delete_dataset_api_key")
+ @console_ns.doc(description="Delete an API key for a dataset")
+ @console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
+ @console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
index 075345d860..3bd61feb44 100644
--- a/api/controllers/console/app/advanced_prompt_template.py
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -1,32 +1,39 @@
-from flask_restx import Resource, fields, reqparse
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
-parser = (
- reqparse.RequestParser()
- .add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
- .add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
- .add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
- .add_argument("model_name", type=str, required=True, location="args", help="Model name")
+
+class AdvancedPromptTemplateQuery(BaseModel):
+ app_mode: str = Field(..., description="Application mode")
+ model_mode: str = Field(..., description="Model mode")
+ has_context: str = Field(default="true", description="Whether the prompt has context")
+ model_name: str = Field(..., description="Model name")
+
+
+console_ns.schema_model(
+ AdvancedPromptTemplateQuery.__name__,
+ AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource):
- @api.doc("get_advanced_prompt_templates")
- @api.doc(description="Get advanced prompt templates based on app mode and model configuration")
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_advanced_prompt_templates")
+ @console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
+ @console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
+ @console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
- @api.response(400, "Invalid request parameters")
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def get(self):
- args = parser.parse_args()
+ args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- return AdvancedPromptTemplateService.get_prompt(args)
+ return AdvancedPromptTemplateService.get_prompt(args.model_dump())
diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py
index fde28fdb98..cfdb9cf417 100644
--- a/api/controllers/console/app/agent.py
+++ b/api/controllers/console/app/agent.py
@@ -1,6 +1,8 @@
-from flask_restx import Resource, fields, reqparse
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value
@@ -8,27 +10,40 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
-parser = (
- reqparse.RequestParser()
- .add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
- .add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AgentLogQuery(BaseModel):
+ message_id: str = Field(..., description="Message UUID")
+ conversation_id: str = Field(..., description="Conversation UUID")
+
+ @field_validator("message_id", "conversation_id")
+ @classmethod
+ def validate_uuid(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+console_ns.schema_model(
+ AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps//agent/logs")
class AgentLogApi(Resource):
- @api.doc("get_agent_logs")
- @api.doc(description="Get agent execution logs for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
- @api.response(400, "Invalid request parameters")
+ @console_ns.doc("get_agent_logs")
+ @console_ns.doc(description="Get agent execution logs for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AgentLogQuery.__name__])
+ @console_ns.response(
+ 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
+ )
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
- args = parser.parse_args()
+ args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
+ return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index bc4113b5c7..3b6fb58931 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -1,10 +1,11 @@
-from typing import Literal
+from typing import Any, Literal
from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@@ -15,29 +16,87 @@ from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
+ build_annotation_model,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AnnotationReplyPayload(BaseModel):
+ score_threshold: float = Field(..., description="Score threshold for annotation matching")
+ embedding_provider_name: str = Field(..., description="Embedding provider name")
+ embedding_model_name: str = Field(..., description="Embedding model name")
+
+
+class AnnotationSettingUpdatePayload(BaseModel):
+ score_threshold: float = Field(..., description="Score threshold")
+
+
+class AnnotationListQuery(BaseModel):
+ page: int = Field(default=1, ge=1, description="Page number")
+ limit: int = Field(default=20, ge=1, description="Page size")
+ keyword: str = Field(default="", description="Search keyword")
+
+
+class CreateAnnotationPayload(BaseModel):
+ message_id: str | None = Field(default=None, description="Message ID")
+ question: str | None = Field(default=None, description="Question text")
+ answer: str | None = Field(default=None, description="Answer text")
+ content: str | None = Field(default=None, description="Content text")
+ annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
+
+ @field_validator("message_id")
+ @classmethod
+ def validate_message_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class UpdateAnnotationPayload(BaseModel):
+ question: str | None = None
+ answer: str | None = None
+ content: str | None = None
+ annotation_reply: dict[str, Any] | None = None
+
+
+class AnnotationReplyStatusQuery(BaseModel):
+ action: Literal["enable", "disable"]
+
+
+class AnnotationFilePayload(BaseModel):
+ message_id: str = Field(..., description="Message ID")
+
+ @field_validator("message_id")
+ @classmethod
+ def validate_message_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+def reg(model: type[BaseModel]) -> None:
+ console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(AnnotationReplyPayload)
+reg(AnnotationSettingUpdatePayload)
+reg(AnnotationListQuery)
+reg(CreateAnnotationPayload)
+reg(UpdateAnnotationPayload)
+reg(AnnotationReplyStatusQuery)
+reg(AnnotationFilePayload)
+
@console_ns.route("/apps//annotation-reply/")
class AnnotationReplyActionApi(Resource):
- @api.doc("annotation_reply_action")
- @api.doc(description="Enable or disable annotation reply for an app")
- @api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
- @api.expect(
- api.model(
- "AnnotationReplyActionRequest",
- {
- "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
- "embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
- "embedding_model_name": fields.String(required=True, description="Embedding model name"),
- },
- )
- )
- @api.response(200, "Action completed successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("annotation_reply_action")
+ @console_ns.doc(description="Enable or disable annotation reply for an app")
+ @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
+ @console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
+ @console_ns.response(200, "Action completed successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -45,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("score_threshold", required=True, type=float, location="json")
- .add_argument("embedding_provider_name", required=True, type=str, location="json")
- .add_argument("embedding_model_name", required=True, type=str, location="json")
- )
- args = parser.parse_args()
+ args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
- result = AppAnnotationService.enable_app_annotation(args, app_id)
+ result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -61,11 +114,11 @@ class AnnotationReplyActionApi(Resource):
@console_ns.route("/apps//annotation-setting")
class AppAnnotationSettingDetailApi(Resource):
- @api.doc("get_annotation_setting")
- @api.doc(description="Get annotation settings for an app")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Annotation settings retrieved successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("get_annotation_setting")
+ @console_ns.doc(description="Get annotation settings for an app")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Annotation settings retrieved successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -78,21 +131,12 @@ class AppAnnotationSettingDetailApi(Resource):
@console_ns.route("/apps//annotation-settings/")
class AppAnnotationSettingUpdateApi(Resource):
- @api.doc("update_annotation_setting")
- @api.doc(description="Update annotation settings for an app")
- @api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
- @api.expect(
- api.model(
- "AnnotationSettingUpdateRequest",
- {
- "score_threshold": fields.Float(required=True, description="Score threshold"),
- "embedding_provider_name": fields.String(required=True, description="Embedding provider"),
- "embedding_model_name": fields.String(required=True, description="Embedding model"),
- },
- )
- )
- @api.response(200, "Settings updated successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("update_annotation_setting")
+ @console_ns.doc(description="Update annotation settings for an app")
+ @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
+ @console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
+ @console_ns.response(200, "Settings updated successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -101,20 +145,19 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
- parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
- args = parser.parse_args()
+ args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
- result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
+ result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200
@console_ns.route("/apps//annotation-reply//status/")
class AnnotationReplyActionStatusApi(Resource):
- @api.doc("get_annotation_reply_action_status")
- @api.doc(description="Get status of annotation reply action job")
- @api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
- @api.response(200, "Job status retrieved successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("get_annotation_reply_action_status")
+ @console_ns.doc(description="Get status of annotation reply action job")
+ @console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
+ @console_ns.response(200, "Job status retrieved successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -138,25 +181,21 @@ class AnnotationReplyActionStatusApi(Resource):
@console_ns.route("/apps//annotations")
class AnnotationApi(Resource):
- @api.doc("list_annotations")
- @api.doc(description="Get annotations for an app with pagination")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser()
- .add_argument("page", type=int, location="args", default=1, help="Page number")
- .add_argument("limit", type=int, location="args", default=20, help="Page size")
- .add_argument("keyword", type=str, location="args", default="", help="Search keyword")
- )
- @api.response(200, "Annotations retrieved successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("list_annotations")
+ @console_ns.doc(description="Get annotations for an app with pagination")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
+ @console_ns.response(200, "Annotations retrieved successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- keyword = request.args.get("keyword", default="", type=str)
+ args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ page = args.page
+ limit = args.limit
+ keyword = args.keyword
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@@ -169,23 +208,12 @@ class AnnotationApi(Resource):
}
return response, 200
- @api.doc("create_annotation")
- @api.doc(description="Create a new annotation for an app")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "CreateAnnotationRequest",
- {
- "message_id": fields.String(description="Message ID (optional)"),
- "question": fields.String(description="Question text (required when message_id not provided)"),
- "answer": fields.String(description="Answer text (use 'answer' or 'content')"),
- "content": fields.String(description="Content text (use 'answer' or 'content')"),
- "annotation_reply": fields.Raw(description="Annotation reply data"),
- },
- )
- )
- @api.response(201, "Annotation created successfully", annotation_fields)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("create_annotation")
+ @console_ns.doc(description="Create a new annotation for an app")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
+ @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -194,16 +222,9 @@ class AnnotationApi(Resource):
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("message_id", required=False, type=uuid_value, location="json")
- .add_argument("question", required=False, type=str, location="json")
- .add_argument("answer", required=False, type=str, location="json")
- .add_argument("content", required=False, type=str, location="json")
- .add_argument("annotation_reply", required=False, type=dict, location="json")
- )
- args = parser.parse_args()
- annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
+ args = CreateAnnotationPayload.model_validate(console_ns.payload)
+ data = args.model_dump(exclude_none=True)
+ annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
@setup_required
@@ -235,11 +256,15 @@ class AnnotationApi(Resource):
@console_ns.route("/apps//annotations/export")
class AnnotationExportApi(Resource):
- @api.doc("export_annotations")
- @api.doc(description="Export all annotations for an app")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("export_annotations")
+ @console_ns.doc(description="Export all annotations for an app")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(
+ 200,
+ "Annotations exported successfully",
+ console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
+ )
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -251,22 +276,15 @@ class AnnotationExportApi(Resource):
return response, 200
-parser = (
- reqparse.RequestParser()
- .add_argument("question", required=True, type=str, location="json")
- .add_argument("answer", required=True, type=str, location="json")
-)
-
-
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
- @api.doc("update_delete_annotation")
- @api.doc(description="Update or delete an annotation")
- @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
- @api.response(200, "Annotation updated successfully", annotation_fields)
- @api.response(204, "Annotation deleted successfully")
- @api.response(403, "Insufficient permissions")
- @api.expect(parser)
+ @console_ns.doc("update_delete_annotation")
+ @console_ns.doc(description="Update or delete an annotation")
+ @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
+ @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
+ @console_ns.response(204, "Annotation deleted successfully")
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -276,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
- args = parser.parse_args()
- annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
+ args = UpdateAnnotationPayload.model_validate(console_ns.payload)
+ annotation = AppAnnotationService.update_app_annotation_directly(
+ args.model_dump(exclude_none=True), app_id, annotation_id
+ )
return annotation
@setup_required
@@ -293,12 +313,12 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
- @api.doc("batch_import_annotations")
- @api.doc(description="Batch import annotations from CSV file")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Batch import started successfully")
- @api.response(403, "Insufficient permissions")
- @api.response(400, "No file uploaded or too many files")
+ @console_ns.doc("batch_import_annotations")
+ @console_ns.doc(description="Batch import annotations from CSV file")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Batch import started successfully")
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(400, "No file uploaded or too many files")
@setup_required
@login_required
@account_initialization_required
@@ -323,11 +343,11 @@ class AnnotationBatchImportApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
class AnnotationBatchImportStatusApi(Resource):
- @api.doc("get_batch_import_status")
- @api.doc(description="Get status of batch import job")
- @api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
- @api.response(200, "Job status retrieved successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("get_batch_import_status")
+ @console_ns.doc(description="Get status of batch import job")
+ @console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
+ @console_ns.response(200, "Job status retrieved successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@@ -350,18 +370,27 @@ class AnnotationBatchImportStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
class AnnotationHitHistoryListApi(Resource):
- @api.doc("list_annotation_hit_histories")
- @api.doc(description="Get hit histories for an annotation")
- @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
- @api.expect(
- api.parser()
+ @console_ns.doc("list_annotation_hit_histories")
+ @console_ns.doc(description="Get hit histories for an annotation")
+ @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
+ @console_ns.expect(
+ console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
)
- @api.response(
- 200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
+ @console_ns.response(
+ 200,
+ "Hit histories retrieved successfully",
+ console_ns.model(
+ "AnnotationHitHistoryList",
+ {
+ "data": fields.List(
+ fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
+ )
+ },
+ ),
)
- @api.response(403, "Insufficient permissions")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index a487512961..62e997dae2 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,11 +1,14 @@
import uuid
+from typing import Literal
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
-from werkzeug.exceptions import BadRequest, abort
+from werkzeug.exceptions import BadRequest
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@@ -18,9 +21,16 @@ from controllers.console.wraps import (
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
-from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
+from fields.app_fields import (
+ deleted_tool_fields,
+ model_config_fields,
+ model_config_partial_fields,
+ site_fields,
+ tag_fields,
+)
+from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
+from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
@@ -28,29 +38,229 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AppListQuery(BaseModel):
+ page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
+ limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
+ mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
+ default="all", description="App mode filter"
+ )
+ name: str | None = Field(default=None, description="Filter by app name")
+ tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
+ is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
+
+ @field_validator("tag_ids", mode="before")
+ @classmethod
+ def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
+ if not value:
+ return None
+
+ if isinstance(value, str):
+ items = [item.strip() for item in value.split(",") if item.strip()]
+ elif isinstance(value, list):
+ items = [str(item).strip() for item in value if item and str(item).strip()]
+ else:
+ raise TypeError("Unsupported tag_ids type.")
+
+ if not items:
+ return None
+
+ try:
+ return [str(uuid.UUID(item)) for item in items]
+ except ValueError as exc:
+ raise ValueError("Invalid UUID format in tag_ids.") from exc
+
+
+class CreateAppPayload(BaseModel):
+ name: str = Field(..., min_length=1, description="App name")
+ description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
+ mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
+ icon_type: str | None = Field(default=None, description="Icon type")
+ icon: str | None = Field(default=None, description="Icon")
+ icon_background: str | None = Field(default=None, description="Icon background color")
+
+
+class UpdateAppPayload(BaseModel):
+ name: str = Field(..., min_length=1, description="App name")
+ description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
+ icon_type: str | None = Field(default=None, description="Icon type")
+ icon: str | None = Field(default=None, description="Icon")
+ icon_background: str | None = Field(default=None, description="Icon background color")
+ use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
+ max_active_requests: int | None = Field(default=None, description="Maximum active requests")
+
+
+class CopyAppPayload(BaseModel):
+ name: str | None = Field(default=None, description="Name for the copied app")
+ description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
+ icon_type: str | None = Field(default=None, description="Icon type")
+ icon: str | None = Field(default=None, description="Icon")
+ icon_background: str | None = Field(default=None, description="Icon background color")
+
+
+class AppExportQuery(BaseModel):
+ include_secret: bool = Field(default=False, description="Include secrets in export")
+ workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
+
+
+class AppNamePayload(BaseModel):
+ name: str = Field(..., min_length=1, description="Name to check")
+
+
+class AppIconPayload(BaseModel):
+ icon: str | None = Field(default=None, description="Icon data")
+ icon_background: str | None = Field(default=None, description="Icon background color")
+
+
+class AppSiteStatusPayload(BaseModel):
+ enable_site: bool = Field(..., description="Enable or disable site")
+
+
+class AppApiStatusPayload(BaseModel):
+ enable_api: bool = Field(..., description="Enable or disable API")
+
+
+class AppTracePayload(BaseModel):
+ enabled: bool = Field(..., description="Enable or disable tracing")
+ tracing_provider: str | None = Field(default=None, description="Tracing provider")
+
+ @field_validator("tracing_provider")
+ @classmethod
+ def validate_tracing_provider(cls, value: str | None, info) -> str | None:
+ if info.data.get("enabled") and not value:
+ raise ValueError("tracing_provider is required when enabled is True")
+ return value
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(AppListQuery)
+reg(CreateAppPayload)
+reg(UpdateAppPayload)
+reg(CopyAppPayload)
+reg(AppExportQuery)
+reg(AppNamePayload)
+reg(AppIconPayload)
+reg(AppSiteStatusPayload)
+reg(AppApiStatusPayload)
+reg(AppTracePayload)
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register base models first
+tag_model = console_ns.model("Tag", tag_fields)
+
+workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
+
+model_config_model = console_ns.model("ModelConfig", model_config_fields)
+
+model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
+
+deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
+
+site_model = console_ns.model("Site", site_fields)
+
+app_partial_model = console_ns.model(
+ "AppPartial",
+ {
+ "id": fields.String,
+ "name": fields.String,
+ "max_active_requests": fields.Raw(),
+ "description": fields.String(attribute="desc_or_prompt"),
+ "mode": fields.String(attribute="mode_compatible_with_agent"),
+ "icon_type": fields.String,
+ "icon": fields.String,
+ "icon_background": fields.String,
+ "icon_url": AppIconUrlField,
+ "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
+ "workflow": fields.Nested(workflow_partial_model, allow_null=True),
+ "use_icon_as_answer_icon": fields.Boolean,
+ "created_by": fields.String,
+ "created_at": TimestampField,
+ "updated_by": fields.String,
+ "updated_at": TimestampField,
+ "tags": fields.List(fields.Nested(tag_model)),
+ "access_mode": fields.String,
+ "create_user_name": fields.String,
+ "author_name": fields.String,
+ "has_draft_trigger": fields.Boolean,
+ },
+)
+
+app_detail_model = console_ns.model(
+ "AppDetail",
+ {
+ "id": fields.String,
+ "name": fields.String,
+ "description": fields.String,
+ "mode": fields.String(attribute="mode_compatible_with_agent"),
+ "icon": fields.String,
+ "icon_background": fields.String,
+ "enable_site": fields.Boolean,
+ "enable_api": fields.Boolean,
+ "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
+ "workflow": fields.Nested(workflow_partial_model, allow_null=True),
+ "tracing": fields.Raw,
+ "use_icon_as_answer_icon": fields.Boolean,
+ "created_by": fields.String,
+ "created_at": TimestampField,
+ "updated_by": fields.String,
+ "updated_at": TimestampField,
+ "access_mode": fields.String,
+ "tags": fields.List(fields.Nested(tag_model)),
+ },
+)
+
+app_detail_with_site_model = console_ns.model(
+ "AppDetailWithSite",
+ {
+ "id": fields.String,
+ "name": fields.String,
+ "description": fields.String,
+ "mode": fields.String(attribute="mode_compatible_with_agent"),
+ "icon_type": fields.String,
+ "icon": fields.String,
+ "icon_background": fields.String,
+ "icon_url": AppIconUrlField,
+ "enable_site": fields.Boolean,
+ "enable_api": fields.Boolean,
+ "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
+ "workflow": fields.Nested(workflow_partial_model, allow_null=True),
+ "api_base_url": fields.String,
+ "use_icon_as_answer_icon": fields.Boolean,
+ "max_active_requests": fields.Integer,
+ "created_by": fields.String,
+ "created_at": TimestampField,
+ "updated_by": fields.String,
+ "updated_at": TimestampField,
+ "deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
+ "access_mode": fields.String,
+ "tags": fields.List(fields.Nested(tag_model)),
+ "site": fields.Nested(site_model),
+ },
+)
+
+app_pagination_model = console_ns.model(
+ "AppPagination",
+ {
+ "page": fields.Integer,
+ "limit": fields.Integer(attribute="per_page"),
+ "total": fields.Integer,
+ "has_more": fields.Boolean(attribute="has_next"),
+ "data": fields.List(fields.Nested(app_partial_model), attribute="items"),
+ },
+)
@console_ns.route("/apps")
class AppListApi(Resource):
- @api.doc("list_apps")
- @api.doc(description="Get list of applications with pagination and filtering")
- @api.expect(
- api.parser()
- .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
- .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
- .add_argument(
- "mode",
- type=str,
- location="args",
- choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
- default="all",
- help="App mode filter",
- )
- .add_argument("name", type=str, location="args", help="Filter by app name")
- .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
- .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
- )
- @api.response(200, "Success", app_pagination_fields)
+ @console_ns.doc("list_apps")
+ @console_ns.doc(description="Get list of applications with pagination and filtering")
+ @console_ns.expect(console_ns.models[AppListQuery.__name__])
+ @console_ns.response(200, "Success", app_pagination_model)
@setup_required
@login_required
@account_initialization_required
@@ -59,42 +269,12 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
- def uuid_list(value):
- try:
- return [str(uuid.UUID(v)) for v in value.split(",")]
- except ValueError:
- abort(400, message="Invalid UUID format in tag_ids.")
-
- parser = (
- reqparse.RequestParser()
- .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
- .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
- .add_argument(
- "mode",
- type=str,
- choices=[
- "completion",
- "chat",
- "advanced-chat",
- "workflow",
- "agent-chat",
- "channel",
- "all",
- ],
- default="all",
- location="args",
- required=False,
- )
- .add_argument("name", type=str, location="args", required=False)
- .add_argument("tag_ids", type=uuid_list, location="args", required=False)
- .add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
- )
-
- args = parser.parse_args()
+ args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ args_dict = args.model_dump()
# get app list
app_service = AppService()
- app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
+ app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -129,75 +309,54 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
- for _, node_data in workflow.walk_nodes():
- if node_data.get("type") in trigger_node_types:
- draft_trigger_app_ids.add(str(workflow.app_id))
- break
+ try:
+ for _, node_data in workflow.walk_nodes():
+ if node_data.get("type") in trigger_node_types:
+ draft_trigger_app_ids.add(str(workflow.app_id))
+ break
+ except Exception:
+ continue
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
- return marshal(app_pagination, app_pagination_fields), 200
+ return marshal(app_pagination, app_pagination_model), 200
- @api.doc("create_app")
- @api.doc(description="Create a new application")
- @api.expect(
- api.model(
- "CreateAppRequest",
- {
- "name": fields.String(required=True, description="App name"),
- "description": fields.String(description="App description (max 400 chars)"),
- "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
- "icon_type": fields.String(description="Icon type"),
- "icon": fields.String(description="Icon"),
- "icon_background": fields.String(description="Icon background color"),
- },
- )
- )
- @api.response(201, "App created successfully", app_detail_fields)
- @api.response(403, "Insufficient permissions")
- @api.response(400, "Invalid request parameters")
+ @console_ns.doc("create_app")
+ @console_ns.doc(description="Create a new application")
+ @console_ns.expect(console_ns.models[CreateAppPayload.__name__])
+ @console_ns.response(201, "App created successfully", app_detail_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @marshal_with(app_detail_fields)
+ @marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("name", type=str, required=True, location="json")
- .add_argument("description", type=validate_description_length, location="json")
- .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
- .add_argument("icon_type", type=str, location="json")
- .add_argument("icon", type=str, location="json")
- .add_argument("icon_background", type=str, location="json")
- )
- args = parser.parse_args()
-
- if "mode" not in args or args["mode"] is None:
- raise BadRequest("mode is required")
+ args = CreateAppPayload.model_validate(console_ns.payload)
app_service = AppService()
- app = app_service.create_app(current_tenant_id, args, current_user)
+ app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
return app, 201
@console_ns.route("/apps/<uuid:app_id>")
class AppApi(Resource):
- @api.doc("get_app_detail")
- @api.doc(description="Get application details")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Success", app_detail_fields_with_site)
+ @console_ns.doc("get_app_detail")
+ @console_ns.doc(description="Get application details")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Success", app_detail_with_site_model)
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model
- @marshal_with(app_detail_fields_with_site)
+ @marshal_with(app_detail_with_site_model)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@@ -210,66 +369,43 @@ class AppApi(Resource):
return app_model
- @api.doc("update_app")
- @api.doc(description="Update application details")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "UpdateAppRequest",
- {
- "name": fields.String(required=True, description="App name"),
- "description": fields.String(description="App description (max 400 chars)"),
- "icon_type": fields.String(description="Icon type"),
- "icon": fields.String(description="Icon"),
- "icon_background": fields.String(description="Icon background color"),
- "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
- "max_active_requests": fields.Integer(description="Maximum active requests"),
- },
- )
- )
- @api.response(200, "App updated successfully", app_detail_fields_with_site)
- @api.response(403, "Insufficient permissions")
- @api.response(400, "Invalid request parameters")
+ @console_ns.doc("update_app")
+ @console_ns.doc(description="Update application details")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
+ @console_ns.response(200, "App updated successfully", app_detail_with_site_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
- @marshal_with(app_detail_fields_with_site)
+ @marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
- parser = (
- reqparse.RequestParser()
- .add_argument("name", type=str, required=True, nullable=False, location="json")
- .add_argument("description", type=validate_description_length, location="json")
- .add_argument("icon_type", type=str, location="json")
- .add_argument("icon", type=str, location="json")
- .add_argument("icon_background", type=str, location="json")
- .add_argument("use_icon_as_answer_icon", type=bool, location="json")
- .add_argument("max_active_requests", type=int, location="json")
- )
- args = parser.parse_args()
+ args = UpdateAppPayload.model_validate(console_ns.payload)
app_service = AppService()
args_dict: AppService.ArgsDict = {
- "name": args["name"],
- "description": args.get("description", ""),
- "icon_type": args.get("icon_type", ""),
- "icon": args.get("icon", ""),
- "icon_background": args.get("icon_background", ""),
- "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
- "max_active_requests": args.get("max_active_requests", 0),
+ "name": args.name,
+ "description": args.description or "",
+ "icon_type": args.icon_type or "",
+ "icon": args.icon or "",
+ "icon_background": args.icon_background or "",
+ "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
+ "max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
return app_model
- @api.doc("delete_app")
- @api.doc(description="Delete application")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(204, "App deleted successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("delete_app")
+ @console_ns.doc(description="Delete application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(204, "App deleted successfully")
+ @console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@@ -285,43 +421,24 @@ class AppApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/copy")
class AppCopyApi(Resource):
- @api.doc("copy_app")
- @api.doc(description="Create a copy of an existing application")
- @api.doc(params={"app_id": "Application ID to copy"})
- @api.expect(
- api.model(
- "CopyAppRequest",
- {
- "name": fields.String(description="Name for the copied app"),
- "description": fields.String(description="Description for the copied app"),
- "icon_type": fields.String(description="Icon type"),
- "icon": fields.String(description="Icon"),
- "icon_background": fields.String(description="Icon background color"),
- },
- )
- )
- @api.response(201, "App copied successfully", app_detail_fields_with_site)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("copy_app")
+ @console_ns.doc(description="Create a copy of an existing application")
+ @console_ns.doc(params={"app_id": "Application ID to copy"})
+ @console_ns.expect(console_ns.models[CopyAppPayload.__name__])
+ @console_ns.response(201, "App copied successfully", app_detail_with_site_model)
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
- @marshal_with(app_detail_fields_with_site)
+ @marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("name", type=str, location="json")
- .add_argument("description", type=validate_description_length, location="json")
- .add_argument("icon_type", type=str, location="json")
- .add_argument("icon", type=str, location="json")
- .add_argument("icon_background", type=str, location="json")
- )
- args = parser.parse_args()
+ args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = AppDslService(session)
@@ -330,11 +447,11 @@ class AppCopyApi(Resource):
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
- name=args.get("name"),
- description=args.get("description"),
- icon_type=args.get("icon_type"),
- icon=args.get("icon"),
- icon_background=args.get("icon_background"),
+ name=args.name,
+ description=args.description,
+ icon_type=args.icon_type,
+ icon=args.icon,
+ icon_background=args.icon_background,
)
session.commit()
@@ -346,20 +463,16 @@ class AppCopyApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/export")
class AppExportApi(Resource):
- @api.doc("export_app")
- @api.doc(description="Export application configuration as DSL")
- @api.doc(params={"app_id": "Application ID to export"})
- @api.expect(
- api.parser()
- .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
- .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
- )
- @api.response(
+ @console_ns.doc("export_app")
+ @console_ns.doc(description="Export application configuration as DSL")
+ @console_ns.doc(params={"app_id": "Application ID to export"})
+ @console_ns.expect(console_ns.models[AppExportQuery.__name__])
+ @console_ns.response(
200,
"App exported successfully",
- api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
+ console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
)
- @api.response(403, "Insufficient permissions")
+ @console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@@ -367,145 +480,114 @@ class AppExportApi(Resource):
@edit_permission_required
def get(self, app_model):
"""Export app"""
- # Add include_secret params
- parser = (
- reqparse.RequestParser()
- .add_argument("include_secret", type=inputs.boolean, default=False, location="args")
- .add_argument("workflow_id", type=str, location="args")
- )
- args = parser.parse_args()
+ args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {
"data": AppDslService.export_dsl(
- app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
+ app_model=app_model,
+ include_secret=args.include_secret,
+ workflow_id=args.workflow_id,
)
}
-parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
-
-
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
- @api.doc("check_app_name")
- @api.doc(description="Check if app name is available")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(200, "Name availability checked")
+ @console_ns.doc("check_app_name")
+ @console_ns.doc(description="Check if app name is available")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AppNamePayload.__name__])
+ @console_ns.response(200, "Name availability checked")
@setup_required
@login_required
@account_initialization_required
@get_app_model
- @marshal_with(app_detail_fields)
+ @marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
- args = parser.parse_args()
+ args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
- app_model = app_service.update_app_name(app_model, args["name"])
+ app_model = app_service.update_app_name(app_model, args.name)
return app_model
@console_ns.route("/apps/<uuid:app_id>/icon")
class AppIconApi(Resource):
- @api.doc("update_app_icon")
- @api.doc(description="Update application icon")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "AppIconRequest",
- {
- "icon": fields.String(required=True, description="Icon data"),
- "icon_type": fields.String(description="Icon type"),
- "icon_background": fields.String(description="Icon background color"),
- },
- )
- )
- @api.response(200, "Icon updated successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("update_app_icon")
+ @console_ns.doc(description="Update application icon")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AppIconPayload.__name__])
+ @console_ns.response(200, "Icon updated successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
- @marshal_with(app_detail_fields)
+ @marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
- parser = (
- reqparse.RequestParser()
- .add_argument("icon", type=str, location="json")
- .add_argument("icon_background", type=str, location="json")
- )
- args = parser.parse_args()
+ args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
- app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
+ app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
return app_model
@console_ns.route("/apps/<uuid:app_id>/site-enable")
class AppSiteStatus(Resource):
- @api.doc("update_app_site_status")
- @api.doc(description="Enable or disable app site")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
- )
- )
- @api.response(200, "Site status updated successfully", app_detail_fields)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("update_app_site_status")
+ @console_ns.doc(description="Enable or disable app site")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
+ @console_ns.response(200, "Site status updated successfully", app_detail_model)
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
- @marshal_with(app_detail_fields)
+ @marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
- parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
- args = parser.parse_args()
+ args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
- app_model = app_service.update_app_site_status(app_model, args["enable_site"])
+ app_model = app_service.update_app_site_status(app_model, args.enable_site)
return app_model
@console_ns.route("/apps/<uuid:app_id>/api-enable")
class AppApiStatus(Resource):
- @api.doc("update_app_api_status")
- @api.doc(description="Enable or disable app API")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
- )
- )
- @api.response(200, "API status updated successfully", app_detail_fields)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("update_app_api_status")
+ @console_ns.doc(description="Enable or disable app API")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
+ @console_ns.response(200, "API status updated successfully", app_detail_model)
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
- @marshal_with(app_detail_fields)
+ @marshal_with(app_detail_model)
def post(self, app_model):
- parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
- args = parser.parse_args()
+ args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
- app_model = app_service.update_app_api_status(app_model, args["enable_api"])
+ app_model = app_service.update_app_api_status(app_model, args.enable_api)
return app_model
@console_ns.route("/apps/<uuid:app_id>/trace")
class AppTraceApi(Resource):
- @api.doc("get_app_trace")
- @api.doc(description="Get app tracing configuration")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Trace configuration retrieved successfully")
+ @console_ns.doc("get_app_trace")
+ @console_ns.doc(description="Get app tracing configuration")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Trace configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -515,37 +597,24 @@ class AppTraceApi(Resource):
return app_trace_config
- @api.doc("update_app_trace")
- @api.doc(description="Update app tracing configuration")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "AppTraceRequest",
- {
- "enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
- "tracing_provider": fields.String(required=True, description="Tracing provider"),
- },
- )
- )
- @api.response(200, "Trace configuration updated successfully")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("update_app_trace")
+ @console_ns.doc(description="Update app tracing configuration")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AppTracePayload.__name__])
+ @console_ns.response(200, "Trace configuration updated successfully")
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id):
# add app trace
- parser = (
- reqparse.RequestParser()
- .add_argument("enabled", type=bool, required=True, location="json")
- .add_argument("tracing_provider", type=str, required=True, location="json")
- )
- args = parser.parse_args()
+ args = AppTracePayload.model_validate(console_ns.payload)
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
- enabled=args["enabled"],
- tracing_provider=args["tracing_provider"],
+ enabled=args.enabled,
+ tracing_provider=args.tracing_provider,
)
return {"result": "success"}
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 02dbd42515..22e2aeb720 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -1,7 +1,7 @@
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
-from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@@ -10,7 +10,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
-from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
+from fields.app_fields import (
+ app_import_check_dependencies_fields,
+ app_import_fields,
+ leaked_dependency_fields,
+)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
@@ -19,33 +23,52 @@ from services.feature_service import FeatureService
from .. import console_ns
-parser = (
- reqparse.RequestParser()
- .add_argument("mode", type=str, required=True, location="json")
- .add_argument("yaml_content", type=str, location="json")
- .add_argument("yaml_url", type=str, location="json")
- .add_argument("name", type=str, location="json")
- .add_argument("description", type=str, location="json")
- .add_argument("icon_type", type=str, location="json")
- .add_argument("icon", type=str, location="json")
- .add_argument("icon_background", type=str, location="json")
- .add_argument("app_id", type=str, location="json")
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register base model first
+leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
+
+app_import_model = console_ns.model("AppImport", app_import_fields)
+
+# For nested models, need to replace nested dict with registered model
+app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
+app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
+app_import_check_dependencies_model = console_ns.model(
+ "AppImportCheckDependencies", app_import_check_dependencies_fields_copy
+)
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AppImportPayload(BaseModel):
+ mode: str = Field(..., description="Import mode")
+ yaml_content: str | None = None
+ yaml_url: str | None = None
+ name: str | None = None
+ description: str | None = None
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ app_id: str | None = None
+
+
+console_ns.schema_model(
+ AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
- @api.expect(parser)
+ @console_ns.expect(console_ns.models[AppImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
- @marshal_with(app_import_fields)
+ @marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with Session(db.engine) as session:
@@ -54,15 +77,15 @@ class AppImportApi(Resource):
account = current_user
result = import_service.import_app(
account=account,
- import_mode=args["mode"],
- yaml_content=args.get("yaml_content"),
- yaml_url=args.get("yaml_url"),
- name=args.get("name"),
- description=args.get("description"),
- icon_type=args.get("icon_type"),
- icon=args.get("icon"),
- icon_background=args.get("icon_background"),
- app_id=args.get("app_id"),
+ import_mode=args.mode,
+ yaml_content=args.yaml_content,
+ yaml_url=args.yaml_url,
+ name=args.name,
+ description=args.description,
+ icon_type=args.icon_type,
+ icon=args.icon,
+ icon_background=args.icon_background,
+ app_id=args.app_id,
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
@@ -82,7 +105,7 @@ class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(app_import_fields)
+ @marshal_with(app_import_model)
@edit_permission_required
def post(self, import_id):
# Check user role first
@@ -108,7 +131,7 @@ class AppImportCheckDependenciesApi(Resource):
@login_required
@get_app_model
@account_initialization_required
- @marshal_with(app_import_check_dependencies_fields)
+ @marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with Session(db.engine) as session:
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 8170ba271a..d344ede466 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -1,11 +1,12 @@
import logging
from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -32,20 +33,41 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class TextToSpeechPayload(BaseModel):
+ message_id: str | None = Field(default=None, description="Message ID")
+ text: str = Field(..., description="Text to convert")
+ voice: str | None = Field(default=None, description="Voice name")
+ streaming: bool | None = Field(default=None, description="Whether to stream audio")
+
+
+class TextToSpeechVoiceQuery(BaseModel):
+ language: str = Field(..., description="Language code")
+
+
+console_ns.schema_model(
+ TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+ TextToSpeechVoiceQuery.__name__,
+ TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
class ChatMessageAudioApi(Resource):
- @api.doc("chat_message_audio_transcript")
- @api.doc(description="Transcript audio to text for chat messages")
- @api.doc(params={"app_id": "App ID"})
- @api.response(
+ @console_ns.doc("chat_message_audio_transcript")
+ @console_ns.doc(description="Transcript audio to text for chat messages")
+ @console_ns.doc(params={"app_id": "App ID"})
+ @console_ns.response(
200,
"Audio transcription successful",
- api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
+ console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
)
- @api.response(400, "Bad request - No audio uploaded or unsupported type")
- @api.response(413, "Audio file too large")
+ @console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
+ @console_ns.response(413, "Audio file too large")
@setup_required
@login_required
@account_initialization_required
@@ -89,43 +111,26 @@ class ChatMessageAudioApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
class ChatMessageTextApi(Resource):
- @api.doc("chat_message_text_to_speech")
- @api.doc(description="Convert text to speech for chat messages")
- @api.doc(params={"app_id": "App ID"})
- @api.expect(
- api.model(
- "TextToSpeechRequest",
- {
- "message_id": fields.String(description="Message ID"),
- "text": fields.String(required=True, description="Text to convert to speech"),
- "voice": fields.String(description="Voice to use for TTS"),
- "streaming": fields.Boolean(description="Whether to stream the audio"),
- },
- )
- )
- @api.response(200, "Text to speech conversion successful")
- @api.response(400, "Bad request - Invalid parameters")
+ @console_ns.doc("chat_message_text_to_speech")
+ @console_ns.doc(description="Convert text to speech for chat messages")
+ @console_ns.doc(params={"app_id": "App ID"})
+ @console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
+ @console_ns.response(200, "Text to speech conversion successful")
+ @console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def post(self, app_model: App):
try:
- parser = (
- reqparse.RequestParser()
- .add_argument("message_id", type=str, location="json")
- .add_argument("text", type=str, location="json")
- .add_argument("voice", type=str, location="json")
- .add_argument("streaming", type=bool, location="json")
- )
- args = parser.parse_args()
-
- message_id = args.get("message_id", None)
- text = args.get("text", None)
- voice = args.get("voice", None)
+ payload = TextToSpeechPayload.model_validate(console_ns.payload)
response = AudioService.transcript_tts(
- app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
+ app_model=app_model,
+ text=payload.text,
+ voice=payload.voice,
+ message_id=payload.message_id,
+ is_draft=True,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -156,24 +161,25 @@ class ChatMessageTextApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
class TextModesApi(Resource):
- @api.doc("get_text_to_speech_voices")
- @api.doc(description="Get available TTS voices for a specific language")
- @api.doc(params={"app_id": "App ID"})
- @api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
- @api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
- @api.response(400, "Invalid language parameter")
+ @console_ns.doc("get_text_to_speech_voices")
+ @console_ns.doc(description="Get available TTS voices for a specific language")
+ @console_ns.doc(params={"app_id": "App ID"})
+ @console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
+ @console_ns.response(
+ 200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
+ )
+ @console_ns.response(400, "Invalid language parameter")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
try:
- parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
- args = parser.parse_args()
+ args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
- language=args["language"],
+ language=args.language,
)
return response
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index d7bc3cc20d..2922121a54 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -1,11 +1,13 @@
import logging
+from typing import Any, Literal
from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -17,7 +19,6 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
-from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@@ -32,50 +33,66 @@ from libs.login import current_user, login_required
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
+from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class BaseMessagePayload(BaseModel):
+ inputs: dict[str, Any]
+ model_config_data: dict[str, Any] = Field(..., alias="model_config")
+ files: list[Any] | None = Field(default=None, description="Uploaded files")
+ response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
+ retriever_from: str = Field(default="dev", description="Retriever source")
+
+
+class CompletionMessagePayload(BaseMessagePayload):
+ query: str = Field(default="", description="Query text")
+
+
+class ChatMessagePayload(BaseMessagePayload):
+ query: str = Field(..., description="User query")
+ conversation_id: str | None = Field(default=None, description="Conversation ID")
+ parent_message_id: str | None = Field(default=None, description="Parent message ID")
+
+ @field_validator("conversation_id", "parent_message_id")
+ @classmethod
+ def validate_uuid(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+console_ns.schema_model(
+ CompletionMessagePayload.__name__,
+ CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
# define completion message api for user
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
class CompletionMessageApi(Resource):
- @api.doc("create_completion_message")
- @api.doc(description="Generate completion message for debugging")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "CompletionMessageRequest",
- {
- "inputs": fields.Raw(required=True, description="Input variables"),
- "query": fields.String(description="Query text", default=""),
- "files": fields.List(fields.Raw(), description="Uploaded files"),
- "model_config": fields.Raw(required=True, description="Model configuration"),
- "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
- "retriever_from": fields.String(default="dev", description="Retriever source"),
- },
- )
- )
- @api.response(200, "Completion generated successfully")
- @api.response(400, "Invalid request parameters")
- @api.response(404, "App not found")
+ @console_ns.doc("create_completion_message")
+ @console_ns.doc(description="Generate completion message for debugging")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
+ @console_ns.response(200, "Completion generated successfully")
+ @console_ns.response(400, "Invalid request parameters")
+ @console_ns.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, location="json")
- .add_argument("query", type=str, location="json", default="")
- .add_argument("files", type=list, required=False, location="json")
- .add_argument("model_config", type=dict, required=True, location="json")
- .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
- )
- args = parser.parse_args()
+ args_model = CompletionMessagePayload.model_validate(console_ns.payload)
+ args = args_model.model_dump(exclude_none=True, by_alias=True)
- streaming = args["response_mode"] != "blocking"
+ streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
try:
@@ -110,10 +127,10 @@ class CompletionMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
class CompletionMessageStopApi(Resource):
- @api.doc("stop_completion_message")
- @api.doc(description="Stop a running completion message generation")
- @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
- @api.response(200, "Task stopped successfully")
+ @console_ns.doc("stop_completion_message")
+ @console_ns.doc(description="Stop a running completion message generation")
+ @console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
+ @console_ns.response(200, "Task stopped successfully")
@setup_required
@login_required
@account_initialization_required
@@ -121,54 +138,36 @@ class CompletionMessageStopApi(Resource):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
- AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+
+ AppTaskService.stop_task(
+ task_id=task_id,
+ invoke_from=InvokeFrom.DEBUGGER,
+ user_id=current_user.id,
+ app_mode=AppMode.value_of(app_model.mode),
+ )
return {"result": "success"}, 200
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageApi(Resource):
- @api.doc("create_chat_message")
- @api.doc(description="Generate chat message for debugging")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "ChatMessageRequest",
- {
- "inputs": fields.Raw(required=True, description="Input variables"),
- "query": fields.String(required=True, description="User query"),
- "files": fields.List(fields.Raw(), description="Uploaded files"),
- "model_config": fields.Raw(required=True, description="Model configuration"),
- "conversation_id": fields.String(description="Conversation ID"),
- "parent_message_id": fields.String(description="Parent message ID"),
- "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
- "retriever_from": fields.String(default="dev", description="Retriever source"),
- },
- )
- )
- @api.response(200, "Chat message generated successfully")
- @api.response(400, "Invalid request parameters")
- @api.response(404, "App or conversation not found")
+ @console_ns.doc("create_chat_message")
+ @console_ns.doc(description="Generate chat message for debugging")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
+ @console_ns.response(200, "Chat message generated successfully")
+ @console_ns.response(400, "Invalid request parameters")
+ @console_ns.response(404, "App or conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, location="json")
- .add_argument("query", type=str, required=True, location="json")
- .add_argument("files", type=list, required=False, location="json")
- .add_argument("model_config", type=dict, required=True, location="json")
- .add_argument("conversation_id", type=uuid_value, location="json")
- .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
- .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
- )
- args = parser.parse_args()
+ args_model = ChatMessagePayload.model_validate(console_ns.payload)
+ args = args_model.model_dump(exclude_none=True, by_alias=True)
- streaming = args["response_mode"] != "blocking"
+ streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
@@ -209,10 +208,10 @@ class ChatMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
class ChatMessageStopApi(Resource):
- @api.doc("stop_chat_message")
- @api.doc(description="Stop a running chat message generation")
- @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
- @api.response(200, "Task stopped successfully")
+ @console_ns.doc("stop_chat_message")
+ @console_ns.doc(description="Stop a running chat message generation")
+ @console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
+ @console_ns.response(200, "Task stopped successfully")
@setup_required
@login_required
@account_initialization_required
@@ -220,6 +219,12 @@ class ChatMessageStopApi(Resource):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
- AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+
+ AppTaskService.stop_task(
+ task_id=task_id,
+ invoke_from=InvokeFrom.DEBUGGER,
+ user_id=current_user.id,
+ app_mode=AppMode.value_of(app_model.mode),
+ )
return {"result": "success"}, 200
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 57b6c314f3..c16dcfd91f 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,88 +1,352 @@
+from typing import Literal
+
import sqlalchemy as sa
-from flask import abort
-from flask_restx import Resource, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import abort, request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import (
- conversation_detail_fields,
- conversation_message_detail_fields,
- conversation_pagination_fields,
- conversation_with_summary_pagination_fields,
-)
+from fields.conversation_fields import MessageTextField
+from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
-from libs.helper import DatetimeString
+from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class BaseConversationQuery(BaseModel):
+ keyword: str | None = Field(default=None, description="Search keyword")
+ start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
+ end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
+ annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
+ default="all", description="Annotation status filter"
+ )
+ page: int = Field(default=1, ge=1, le=99999, description="Page number")
+ limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
+
+ @field_validator("start", "end", mode="before")
+ @classmethod
+ def blank_to_none(cls, value: str | None) -> str | None:
+ if value == "":
+ return None
+ return value
+
+
+class CompletionConversationQuery(BaseConversationQuery):
+ pass
+
+
+class ChatConversationQuery(BaseConversationQuery):
+ sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
+ default="-updated_at", description="Sort field and direction"
+ )
+
+
+console_ns.schema_model(
+ CompletionConversationQuery.__name__,
+ CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ ChatConversationQuery.__name__,
+ ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register in dependency order: base models first, then dependent models
+
+# Base models
+simple_account_model = console_ns.model(
+ "SimpleAccount",
+ {
+ "id": fields.String,
+ "name": fields.String,
+ "email": fields.String,
+ },
+)
+
+feedback_stat_model = console_ns.model(
+ "FeedbackStat",
+ {
+ "like": fields.Integer,
+ "dislike": fields.Integer,
+ },
+)
+
+status_count_model = console_ns.model(
+ "StatusCount",
+ {
+ "success": fields.Integer,
+ "failed": fields.Integer,
+ "partial_success": fields.Integer,
+ },
+)
+
+message_file_model = console_ns.model(
+ "MessageFile",
+ {
+ "id": fields.String,
+ "filename": fields.String,
+ "type": fields.String,
+ "url": fields.String,
+ "mime_type": fields.String,
+ "size": fields.Integer,
+ "transfer_method": fields.String,
+ "belongs_to": fields.String(default="user"),
+ "upload_file_id": fields.String(default=None),
+ },
+)
+
+agent_thought_model = console_ns.model(
+ "AgentThought",
+ {
+ "id": fields.String,
+ "chain_id": fields.String,
+ "message_id": fields.String,
+ "position": fields.Integer,
+ "thought": fields.String,
+ "tool": fields.String,
+ "tool_labels": fields.Raw,
+ "tool_input": fields.String,
+ "created_at": TimestampField,
+ "observation": fields.String,
+ "files": fields.List(fields.String),
+ },
+)
+
+simple_model_config_model = console_ns.model(
+ "SimpleModelConfig",
+ {
+ "model": fields.Raw(attribute="model_dict"),
+ "pre_prompt": fields.String,
+ },
+)
+
+model_config_model = console_ns.model(
+ "ModelConfig",
+ {
+ "opening_statement": fields.String,
+ "suggested_questions": fields.Raw,
+ "model": fields.Raw,
+ "user_input_form": fields.Raw,
+ "pre_prompt": fields.String,
+ "agent_mode": fields.Raw,
+ },
+)
+
+# Models that depend on simple_account_model
+feedback_model = console_ns.model(
+ "Feedback",
+ {
+ "rating": fields.String,
+ "content": fields.String,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_account": fields.Nested(simple_account_model, allow_null=True),
+ },
+)
+
+annotation_model = console_ns.model(
+ "Annotation",
+ {
+ "id": fields.String,
+ "question": fields.String,
+ "content": fields.String,
+ "account": fields.Nested(simple_account_model, allow_null=True),
+ "created_at": TimestampField,
+ },
+)
+
+annotation_hit_history_model = console_ns.model(
+ "AnnotationHitHistory",
+ {
+ "annotation_id": fields.String(attribute="id"),
+ "annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
+ "created_at": TimestampField,
+ },
+)
+
+# Simple message detail model
+simple_message_detail_model = console_ns.model(
+ "SimpleMessageDetail",
+ {
+ "inputs": FilesContainedField,
+ "query": fields.String,
+ "message": MessageTextField,
+ "answer": fields.String,
+ },
+)
+
+# Message detail model that depends on multiple models
+message_detail_model = console_ns.model(
+ "MessageDetail",
+ {
+ "id": fields.String,
+ "conversation_id": fields.String,
+ "inputs": FilesContainedField,
+ "query": fields.String,
+ "message": fields.Raw,
+ "message_tokens": fields.Integer,
+ "answer": fields.String(attribute="re_sign_file_url_answer"),
+ "answer_tokens": fields.Integer,
+ "provider_response_latency": fields.Float,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_account_id": fields.String,
+ "feedbacks": fields.List(fields.Nested(feedback_model)),
+ "workflow_run_id": fields.String,
+ "annotation": fields.Nested(annotation_model, allow_null=True),
+ "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
+ "created_at": TimestampField,
+ "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
+ "message_files": fields.List(fields.Nested(message_file_model)),
+ "metadata": fields.Raw(attribute="message_metadata_dict"),
+ "status": fields.String,
+ "error": fields.String,
+ "parent_message_id": fields.String,
+ },
+)
+
+# Conversation models
+conversation_fields_model = console_ns.model(
+ "Conversation",
+ {
+ "id": fields.String,
+ "status": fields.String,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_end_user_session_id": fields.String(),
+ "from_account_id": fields.String,
+ "from_account_name": fields.String,
+ "read_at": TimestampField,
+ "created_at": TimestampField,
+ "updated_at": TimestampField,
+ "annotation": fields.Nested(annotation_model, allow_null=True),
+ "model_config": fields.Nested(simple_model_config_model),
+ "user_feedback_stats": fields.Nested(feedback_stat_model),
+ "admin_feedback_stats": fields.Nested(feedback_stat_model),
+ "message": fields.Nested(simple_message_detail_model, attribute="first_message"),
+ },
+)
+
+conversation_pagination_model = console_ns.model(
+ "ConversationPagination",
+ {
+ "page": fields.Integer,
+ "limit": fields.Integer(attribute="per_page"),
+ "total": fields.Integer,
+ "has_more": fields.Boolean(attribute="has_next"),
+ "data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
+ },
+)
+
+conversation_message_detail_model = console_ns.model(
+ "ConversationMessageDetail",
+ {
+ "id": fields.String,
+ "status": fields.String,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_account_id": fields.String,
+ "created_at": TimestampField,
+ "model_config": fields.Nested(model_config_model),
+ "message": fields.Nested(message_detail_model, attribute="first_message"),
+ },
+)
+
+conversation_with_summary_model = console_ns.model(
+ "ConversationWithSummary",
+ {
+ "id": fields.String,
+ "status": fields.String,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_end_user_session_id": fields.String,
+ "from_account_id": fields.String,
+ "from_account_name": fields.String,
+ "name": fields.String,
+ "summary": fields.String(attribute="summary_or_query"),
+ "read_at": TimestampField,
+ "created_at": TimestampField,
+ "updated_at": TimestampField,
+ "annotated": fields.Boolean,
+ "model_config": fields.Nested(simple_model_config_model),
+ "message_count": fields.Integer,
+ "user_feedback_stats": fields.Nested(feedback_stat_model),
+ "admin_feedback_stats": fields.Nested(feedback_stat_model),
+ "status_count": fields.Nested(status_count_model),
+ },
+)
+
+conversation_with_summary_pagination_model = console_ns.model(
+ "ConversationWithSummaryPagination",
+ {
+ "page": fields.Integer,
+ "limit": fields.Integer(attribute="per_page"),
+ "total": fields.Integer,
+ "has_more": fields.Boolean(attribute="has_next"),
+ "data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
+ },
+)
+
+conversation_detail_model = console_ns.model(
+ "ConversationDetail",
+ {
+ "id": fields.String,
+ "status": fields.String,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_account_id": fields.String,
+ "created_at": TimestampField,
+ "updated_at": TimestampField,
+ "annotated": fields.Boolean,
+ "introduction": fields.String,
+ "model_config": fields.Nested(model_config_model),
+ "message_count": fields.Integer,
+ "user_feedback_stats": fields.Nested(feedback_stat_model),
+ "admin_feedback_stats": fields.Nested(feedback_stat_model),
+ },
+)
+
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
class CompletionConversationApi(Resource):
- @api.doc("list_completion_conversations")
- @api.doc(description="Get completion conversations with pagination and filtering")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser()
- .add_argument("keyword", type=str, location="args", help="Search keyword")
- .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
- .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
- .add_argument(
- "annotation_status",
- type=str,
- location="args",
- choices=["annotated", "not_annotated", "all"],
- default="all",
- help="Annotation status filter",
- )
- .add_argument("page", type=int, location="args", default=1, help="Page number")
- .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
- )
- @api.response(200, "Success", conversation_pagination_fields)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("list_completion_conversations")
+ @console_ns.doc(description="Get completion conversations with pagination and filtering")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
+ @console_ns.response(200, "Success", conversation_pagination_model)
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
- @marshal_with(conversation_pagination_fields)
+ @marshal_with(conversation_pagination_model)
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("keyword", type=str, location="args")
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument(
- "annotation_status",
- type=str,
- choices=["annotated", "not_annotated", "all"],
- default="all",
- location="args",
- )
- .add_argument("page", type=int_range(1, 99999), default=1, location="args")
- .add_argument("limit", type=int_range(1, 100), default=20, location="args")
- )
- args = parser.parse_args()
+ args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
- if args["keyword"]:
+ if args.keyword:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
- Message.query.ilike(f"%{args['keyword']}%"),
- Message.answer.ilike(f"%{args['keyword']}%"),
+ Message.query.ilike(f"%{args.keyword}%"),
+ Message.answer.ilike(f"%{args.keyword}%"),
)
)
@@ -90,7 +354,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -102,11 +366,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
- if args["annotation_status"] == "annotated":
+ if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
- elif args["annotation_status"] == "not_annotated":
+ elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
@@ -115,36 +379,36 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc())
- conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
+ conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
class CompletionConversationDetailApi(Resource):
- @api.doc("get_completion_conversation")
- @api.doc(description="Get completion conversation details with messages")
- @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
- @api.response(200, "Success", conversation_message_detail_fields)
- @api.response(403, "Insufficient permissions")
- @api.response(404, "Conversation not found")
+ @console_ns.doc("get_completion_conversation")
+ @console_ns.doc(description="Get completion conversation details with messages")
+ @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
+ @console_ns.response(200, "Success", conversation_message_detail_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
- @marshal_with(conversation_message_detail_fields)
+ @marshal_with(conversation_message_detail_model)
@edit_permission_required
def get(self, app_model, conversation_id):
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
- @api.doc("delete_completion_conversation")
- @api.doc(description="Delete a completion conversation")
- @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
- @api.response(204, "Conversation deleted successfully")
- @api.response(403, "Insufficient permissions")
- @api.response(404, "Conversation not found")
+ @console_ns.doc("delete_completion_conversation")
+ @console_ns.doc(description="Delete a completion conversation")
+ @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
+ @console_ns.response(204, "Conversation deleted successfully")
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@@ -164,69 +428,21 @@ class CompletionConversationDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
class ChatConversationApi(Resource):
- @api.doc("list_chat_conversations")
- @api.doc(description="Get chat conversations with pagination, filtering and summary")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser()
- .add_argument("keyword", type=str, location="args", help="Search keyword")
- .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
- .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
- .add_argument(
- "annotation_status",
- type=str,
- location="args",
- choices=["annotated", "not_annotated", "all"],
- default="all",
- help="Annotation status filter",
- )
- .add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
- .add_argument("page", type=int, location="args", default=1, help="Page number")
- .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
- .add_argument(
- "sort_by",
- type=str,
- location="args",
- choices=["created_at", "-created_at", "updated_at", "-updated_at"],
- default="-updated_at",
- help="Sort field and direction",
- )
- )
- @api.response(200, "Success", conversation_with_summary_pagination_fields)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("list_chat_conversations")
+ @console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
+ @console_ns.response(200, "Success", conversation_with_summary_pagination_model)
+ @console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
- @marshal_with(conversation_with_summary_pagination_fields)
+ @marshal_with(conversation_with_summary_pagination_model)
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("keyword", type=str, location="args")
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument(
- "annotation_status",
- type=str,
- choices=["annotated", "not_annotated", "all"],
- default="all",
- location="args",
- )
- .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
- .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- .add_argument(
- "sort_by",
- type=str,
- choices=["created_at", "-created_at", "updated_at", "-updated_at"],
- required=False,
- default="-updated_at",
- location="args",
- )
- )
- args = parser.parse_args()
+ args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
@@ -238,8 +454,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
- if args["keyword"]:
- keyword_filter = f"%{args['keyword']}%"
+ if args.keyword:
+ keyword_filter = f"%{args.keyword}%"
query = (
query.join(
Message,
@@ -262,12 +478,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
- match args["sort_by"]:
+ match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
@@ -275,35 +491,27 @@ class ChatConversationApi(Resource):
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
- match args["sort_by"]:
+ match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
- if args["annotation_status"] == "annotated":
+ if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
- elif args["annotation_status"] == "not_annotated":
+ elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
- if args["message_count_gte"] and args["message_count_gte"] >= 1:
- query = (
- query.options(joinedload(Conversation.messages)) # type: ignore
- .join(Message, Message.conversation_id == Conversation.id)
- .group_by(Conversation.id)
- .having(func.count(Message.id) >= args["message_count_gte"])
- )
-
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
- match args["sort_by"]:
+ match args.sort_by:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
@@ -315,36 +523,36 @@ class ChatConversationApi(Resource):
case _:
query = query.order_by(Conversation.created_at.desc())
- conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
+ conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
class ChatConversationDetailApi(Resource):
- @api.doc("get_chat_conversation")
- @api.doc(description="Get chat conversation details")
- @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
- @api.response(200, "Success", conversation_detail_fields)
- @api.response(403, "Insufficient permissions")
- @api.response(404, "Conversation not found")
+ @console_ns.doc("get_chat_conversation")
+ @console_ns.doc(description="Get chat conversation details")
+ @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
+ @console_ns.response(200, "Success", conversation_detail_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
- @marshal_with(conversation_detail_fields)
+ @marshal_with(conversation_detail_model)
@edit_permission_required
def get(self, app_model, conversation_id):
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
- @api.doc("delete_chat_conversation")
- @api.doc(description="Delete a chat conversation")
- @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
- @api.response(204, "Conversation deleted successfully")
- @api.response(403, "Insufficient permissions")
- @api.response(404, "Conversation not found")
+ @console_ns.doc("delete_chat_conversation")
+ @console_ns.doc(description="Delete a chat conversation")
+ @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
+ @console_ns.response(204, "Conversation deleted successfully")
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index d4c0b5697f..368a6112ba 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -1,46 +1,68 @@
-from flask_restx import Resource, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
-from fields.conversation_variable_fields import paginated_conversation_variable_fields
+from fields.conversation_variable_fields import (
+ conversation_variable_fields,
+ paginated_conversation_variable_fields,
+)
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ConversationVariablesQuery(BaseModel):
+ conversation_id: str = Field(..., description="Conversation ID to filter variables")
+
+
+console_ns.schema_model(
+ ConversationVariablesQuery.__name__,
+ ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register base model first
+conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
+
+# For nested models, need to replace nested dict with registered model
+paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
+paginated_conversation_variable_fields_copy["data"] = fields.List(
+ fields.Nested(conversation_variable_model), attribute="data"
+)
+paginated_conversation_variable_model = console_ns.model(
+ "PaginatedConversationVariable", paginated_conversation_variable_fields_copy
+)
+
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
class ConversationVariablesApi(Resource):
- @api.doc("get_conversation_variables")
- @api.doc(description="Get conversation variables for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser().add_argument(
- "conversation_id", type=str, location="args", help="Conversation ID to filter variables"
- )
- )
- @api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
+ @console_ns.doc("get_conversation_variables")
+ @console_ns.doc(description="Get conversation variables for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
+ @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
- @marshal_with(paginated_conversation_variable_fields)
+ @marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
- parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
- args = parser.parse_args()
+ args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
- if args["conversation_id"]:
- stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
- else:
- raise ValueError("conversation_id is required")
+ stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 54a101946c..b4fc44767a 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,8 +1,10 @@
from collections.abc import Sequence
+from typing import Any
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@@ -21,43 +23,70 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class RuleGeneratePayload(BaseModel):
+ instruction: str = Field(..., description="Rule generation instruction")
+ model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+ no_variable: bool = Field(default=False, description="Whether to exclude variables")
+
+
+class RuleCodeGeneratePayload(RuleGeneratePayload):
+ code_language: str = Field(default="javascript", description="Programming language for code generation")
+
+
+class RuleStructuredOutputPayload(BaseModel):
+ instruction: str = Field(..., description="Structured output generation instruction")
+ model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+
+
+class InstructionGeneratePayload(BaseModel):
+ flow_id: str = Field(..., description="Workflow/Flow ID")
+ node_id: str = Field(default="", description="Node ID for workflow context")
+ current: str = Field(default="", description="Current instruction text")
+ language: str = Field(default="javascript", description="Programming language (javascript/python)")
+ instruction: str = Field(..., description="Instruction for generation")
+ model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+ ideal_output: str = Field(default="", description="Expected ideal output")
+
+
+class InstructionTemplatePayload(BaseModel):
+ type: str = Field(..., description="Instruction template type")
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(RuleGeneratePayload)
+reg(RuleCodeGeneratePayload)
+reg(RuleStructuredOutputPayload)
+reg(InstructionGeneratePayload)
+reg(InstructionTemplatePayload)
+
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
- @api.doc("generate_rule_config")
- @api.doc(description="Generate rule configuration using LLM")
- @api.expect(
- api.model(
- "RuleGenerateRequest",
- {
- "instruction": fields.String(required=True, description="Rule generation instruction"),
- "model_config": fields.Raw(required=True, description="Model configuration"),
- "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
- },
- )
- )
- @api.response(200, "Rule configuration generated successfully")
- @api.response(400, "Invalid request parameters")
- @api.response(402, "Provider quota exceeded")
+ @console_ns.doc("generate_rule_config")
+ @console_ns.doc(description="Generate rule configuration using LLM")
+ @console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
+ @console_ns.response(200, "Rule configuration generated successfully")
+ @console_ns.response(400, "Invalid request parameters")
+ @console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("instruction", type=str, required=True, nullable=False, location="json")
- .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
- .add_argument("no_variable", type=bool, required=True, default=False, location="json")
- )
- args = parser.parse_args()
+ args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
- instruction=args["instruction"],
- model_config=args["model_config"],
- no_variable=args["no_variable"],
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -73,44 +102,25 @@ class RuleGenerateApi(Resource):
@console_ns.route("/rule-code-generate")
class RuleCodeGenerateApi(Resource):
- @api.doc("generate_rule_code")
- @api.doc(description="Generate code rules using LLM")
- @api.expect(
- api.model(
- "RuleCodeGenerateRequest",
- {
- "instruction": fields.String(required=True, description="Code generation instruction"),
- "model_config": fields.Raw(required=True, description="Model configuration"),
- "no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
- "code_language": fields.String(
- default="javascript", description="Programming language for code generation"
- ),
- },
- )
- )
- @api.response(200, "Code rules generated successfully")
- @api.response(400, "Invalid request parameters")
- @api.response(402, "Provider quota exceeded")
+ @console_ns.doc("generate_rule_code")
+ @console_ns.doc(description="Generate code rules using LLM")
+ @console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
+ @console_ns.response(200, "Code rules generated successfully")
+ @console_ns.response(400, "Invalid request parameters")
+ @console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("instruction", type=str, required=True, nullable=False, location="json")
- .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
- .add_argument("no_variable", type=bool, required=True, default=False, location="json")
- .add_argument("code_language", type=str, required=False, default="javascript", location="json")
- )
- args = parser.parse_args()
+ args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
- instruction=args["instruction"],
- model_config=args["model_config"],
- code_language=args["code_language"],
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ code_language=args.code_language,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -126,37 +136,24 @@ class RuleCodeGenerateApi(Resource):
@console_ns.route("/rule-structured-output-generate")
class RuleStructuredOutputGenerateApi(Resource):
- @api.doc("generate_structured_output")
- @api.doc(description="Generate structured output rules using LLM")
- @api.expect(
- api.model(
- "StructuredOutputGenerateRequest",
- {
- "instruction": fields.String(required=True, description="Structured output generation instruction"),
- "model_config": fields.Raw(required=True, description="Model configuration"),
- },
- )
- )
- @api.response(200, "Structured output generated successfully")
- @api.response(400, "Invalid request parameters")
- @api.response(402, "Provider quota exceeded")
+ @console_ns.doc("generate_structured_output")
+ @console_ns.doc(description="Generate structured output rules using LLM")
+ @console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
+ @console_ns.response(200, "Structured output generated successfully")
+ @console_ns.response(400, "Invalid request parameters")
+ @console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("instruction", type=str, required=True, nullable=False, location="json")
- .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
- instruction=args["instruction"],
- model_config=args["model_config"],
+ instruction=args.instruction,
+ model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -172,102 +169,79 @@ class RuleStructuredOutputGenerateApi(Resource):
@console_ns.route("/instruction-generate")
class InstructionGenerateApi(Resource):
- @api.doc("generate_instruction")
- @api.doc(description="Generate instruction for workflow nodes or general use")
- @api.expect(
- api.model(
- "InstructionGenerateRequest",
- {
- "flow_id": fields.String(required=True, description="Workflow/Flow ID"),
- "node_id": fields.String(description="Node ID for workflow context"),
- "current": fields.String(description="Current instruction text"),
- "language": fields.String(default="javascript", description="Programming language (javascript/python)"),
- "instruction": fields.String(required=True, description="Instruction for generation"),
- "model_config": fields.Raw(required=True, description="Model configuration"),
- "ideal_output": fields.String(description="Expected ideal output"),
- },
- )
- )
- @api.response(200, "Instruction generated successfully")
- @api.response(400, "Invalid request parameters or flow/workflow not found")
- @api.response(402, "Provider quota exceeded")
+ @console_ns.doc("generate_instruction")
+ @console_ns.doc(description="Generate instruction for workflow nodes or general use")
+ @console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
+ @console_ns.response(200, "Instruction generated successfully")
+ @console_ns.response(400, "Invalid request parameters or flow/workflow not found")
+ @console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("flow_id", type=str, required=True, default="", location="json")
- .add_argument("node_id", type=str, required=False, default="", location="json")
- .add_argument("current", type=str, required=False, default="", location="json")
- .add_argument("language", type=str, required=False, default="javascript", location="json")
- .add_argument("instruction", type=str, required=True, nullable=False, location="json")
- .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
- .add_argument("ideal_output", type=str, required=False, default="", location="json")
- )
- args = parser.parse_args()
+ args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
- (p for p in providers if p.is_accept_language(args["language"])), None
+ (p for p in providers if p.is_accept_language(args.language)), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
- if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
- app = db.session.query(App).where(App.id == args["flow_id"]).first()
+ if (args.current in (code_template, "")) and args.node_id != "":
+ app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
- return {"error": f"app {args['flow_id']} not found"}, 400
+ return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
- return {"error": f"workflow {args['flow_id']} not found"}, 400
+ return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
- node = [node for node in nodes if node["id"] == args["node_id"]]
+ node = [node for node in nodes if node["id"] == args.node_id]
if len(node) == 0:
- return {"error": f"node {args['node_id']} not found"}, 400
+ return {"error": f"node {args.node_id} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
- instruction=args["instruction"],
- model_config=args["model_config"],
+ instruction=args.instruction,
+ model_config=args.model_config_data,
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
- instruction=args["instruction"],
- model_config=args["model_config"],
+ instruction=args.instruction,
+ model_config=args.model_config_data,
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
- instruction=args["instruction"],
- model_config=args["model_config"],
- code_language=args["language"],
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ code_language=args.language,
)
case _:
return {"error": f"invalid node type: {node_type}"}
- if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
+ if args.node_id == "" and args.current != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
- flow_id=args["flow_id"],
- current=args["current"],
- instruction=args["instruction"],
- model_config=args["model_config"],
- ideal_output=args["ideal_output"],
+ flow_id=args.flow_id,
+ current=args.current,
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ ideal_output=args.ideal_output,
)
- if args["node_id"] != "" and args["current"] != "": # For workflow node
+ if args.node_id != "" and args.current != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
- flow_id=args["flow_id"],
- node_id=args["node_id"],
- current=args["current"],
- instruction=args["instruction"],
- model_config=args["model_config"],
- ideal_output=args["ideal_output"],
+ flow_id=args.flow_id,
+ node_id=args.node_id,
+ current=args.current,
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ ideal_output=args.ideal_output,
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@@ -283,26 +257,17 @@ class InstructionGenerateApi(Resource):
@console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource):
- @api.doc("get_instruction_template")
- @api.doc(description="Get instruction generation template")
- @api.expect(
- api.model(
- "InstructionTemplateRequest",
- {
- "instruction": fields.String(required=True, description="Template instruction"),
- "ideal_output": fields.String(description="Expected ideal output"),
- },
- )
- )
- @api.response(200, "Template retrieved successfully")
- @api.response(400, "Invalid request parameters")
+ @console_ns.doc("get_instruction_template")
+ @console_ns.doc(description="Get instruction generation template")
+ @console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
+ @console_ns.response(200, "Template retrieved successfully")
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
- parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
- args = parser.parse_args()
- match args["type"]:
+ args = InstructionTemplatePayload.model_validate(console_ns.payload)
+ match args.type:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
- raise ValueError(f"Invalid type: {args['type']}")
+ raise ValueError(f"Invalid type: {args.type}")
diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py
index 3700c6b1d0..dd982b6d7b 100644
--- a/api/controllers/console/app/mcp_server.py
+++ b/api/controllers/console/app/mcp_server.py
@@ -1,10 +1,11 @@
import json
from enum import StrEnum
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
@@ -12,64 +13,72 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+# Register model for flask_restx to avoid dict type issues in Swagger
+app_server_model = console_ns.model("AppServer", app_server_fields)
+
class AppMCPServerStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
+class MCPServerCreatePayload(BaseModel):
+ description: str | None = Field(default=None, description="Server description")
+ parameters: dict = Field(..., description="Server parameters configuration")
+
+
+class MCPServerUpdatePayload(BaseModel):
+ id: str = Field(..., description="Server ID")
+ description: str | None = Field(default=None, description="Server description")
+ parameters: dict = Field(..., description="Server parameters configuration")
+ status: str | None = Field(default=None, description="Server status")
+
+
+for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
+ console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
- @api.doc("get_app_mcp_server")
- @api.doc(description="Get MCP server configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
+ @console_ns.doc("get_app_mcp_server")
+ @console_ns.doc(description="Get MCP server configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
@login_required
@account_initialization_required
@setup_required
@get_app_model
- @marshal_with(app_server_fields)
+ @marshal_with(app_server_model)
def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server
- @api.doc("create_app_mcp_server")
- @api.doc(description="Create MCP server configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "MCPServerCreateRequest",
- {
- "description": fields.String(description="Server description"),
- "parameters": fields.Raw(required=True, description="Server parameters configuration"),
- },
- )
- )
- @api.response(201, "MCP server configuration created successfully", app_server_fields)
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("create_app_mcp_server")
+ @console_ns.doc(description="Create MCP server configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
+ @console_ns.response(201, "MCP server configuration created successfully", app_server_model)
+ @console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@login_required
@setup_required
- @marshal_with(app_server_fields)
+ @marshal_with(app_server_model)
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("description", type=str, required=False, location="json")
- .add_argument("parameters", type=dict, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
- description = args.get("description")
+ description = payload.description
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=description,
- parameters=json.dumps(args["parameters"], ensure_ascii=False),
+ parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_tenant_id,
@@ -79,43 +88,26 @@ class AppMCPServerController(Resource):
db.session.commit()
return server
- @api.doc("update_app_mcp_server")
- @api.doc(description="Update MCP server configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "MCPServerUpdateRequest",
- {
- "id": fields.String(required=True, description="Server ID"),
- "description": fields.String(description="Server description"),
- "parameters": fields.Raw(required=True, description="Server parameters configuration"),
- "status": fields.String(description="Server status"),
- },
- )
- )
- @api.response(200, "MCP server configuration updated successfully", app_server_fields)
- @api.response(403, "Insufficient permissions")
- @api.response(404, "Server not found")
+ @console_ns.doc("update_app_mcp_server")
+ @console_ns.doc(description="Update MCP server configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
+ @console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "Server not found")
@get_app_model
@login_required
@setup_required
@account_initialization_required
- @marshal_with(app_server_fields)
+ @marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
- parser = (
- reqparse.RequestParser()
- .add_argument("id", type=str, required=True, location="json")
- .add_argument("description", type=str, required=False, location="json")
- .add_argument("parameters", type=dict, required=True, location="json")
- .add_argument("status", type=str, required=False, location="json")
- )
- args = parser.parse_args()
- server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
+ payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
+ server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
if not server:
raise NotFound()
- description = args.get("description")
+ description = payload.description
if description is None:
pass
elif not description:
@@ -123,27 +115,27 @@ class AppMCPServerController(Resource):
else:
server.description = description
- server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
- if args["status"]:
- if args["status"] not in [status.value for status in AppMCPServerStatus]:
+ server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
+ if payload.status:
+ if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
- server.status = args["status"]
+ server.status = payload.status
db.session.commit()
return server
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
class AppMCPServerRefreshController(Resource):
- @api.doc("refresh_app_mcp_server")
- @api.doc(description="Refresh MCP server configuration and regenerate server code")
- @api.doc(params={"server_id": "Server ID"})
- @api.response(200, "MCP server refreshed successfully", app_server_fields)
- @api.response(403, "Insufficient permissions")
- @api.response(404, "Server not found")
+ @console_ns.doc("refresh_app_mcp_server")
+ @console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
+ @console_ns.doc(params={"server_id": "Server ID"})
+ @console_ns.response(200, "MCP server refreshed successfully", app_server_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
- @marshal_with(app_server_fields)
+ @marshal_with(app_server_model)
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 3f66278940..12ada8b798 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,11 +1,13 @@
import logging
+from typing import Literal
-from flask_restx import Resource, fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@@ -23,8 +25,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
-from fields.conversation_fields import message_detail_fields
-from libs.helper import uuid_value
+from fields.raws import FilesContainedField
+from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
@@ -33,55 +35,217 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ChatMessagesQuery(BaseModel):
+ conversation_id: str = Field(..., description="Conversation ID")
+ first_id: str | None = Field(default=None, description="First message ID for pagination")
+ limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
+
+ @field_validator("first_id", mode="before")
+ @classmethod
+ def empty_to_none(cls, value: str | None) -> str | None:
+ if value == "":
+ return None
+ return value
+
+ @field_validator("conversation_id", "first_id")
+ @classmethod
+ def validate_uuid(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class MessageFeedbackPayload(BaseModel):
+ message_id: str = Field(..., description="Message ID")
+ rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
+ content: str | None = Field(default=None, description="Feedback content")
+
+ @field_validator("message_id")
+ @classmethod
+ def validate_message_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class FeedbackExportQuery(BaseModel):
+ from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
+ rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
+ has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
+ start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
+ end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
+ format: Literal["csv", "json"] = Field(default="csv", description="Export format")
+
+ @field_validator("has_comment", mode="before")
+ @classmethod
+ def parse_bool(cls, value: bool | str | None) -> bool | None:
+ if isinstance(value, bool) or value is None:
+ return value
+ lowered = value.lower()
+ if lowered in {"true", "1", "yes", "on"}:
+ return True
+ if lowered in {"false", "0", "no", "off"}:
+ return False
+ raise ValueError("has_comment must be a boolean value")
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(ChatMessagesQuery)
+reg(MessageFeedbackPayload)
+reg(FeedbackExportQuery)
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register in dependency order: base models first, then dependent models
+
+# Base models
+simple_account_model = console_ns.model(
+ "SimpleAccount",
+ {
+ "id": fields.String,
+ "name": fields.String,
+ "email": fields.String,
+ },
+)
+
+message_file_model = console_ns.model(
+ "MessageFile",
+ {
+ "id": fields.String,
+ "filename": fields.String,
+ "type": fields.String,
+ "url": fields.String,
+ "mime_type": fields.String,
+ "size": fields.Integer,
+ "transfer_method": fields.String,
+ "belongs_to": fields.String(default="user"),
+ "upload_file_id": fields.String(default=None),
+ },
+)
+
+agent_thought_model = console_ns.model(
+ "AgentThought",
+ {
+ "id": fields.String,
+ "chain_id": fields.String,
+ "message_id": fields.String,
+ "position": fields.Integer,
+ "thought": fields.String,
+ "tool": fields.String,
+ "tool_labels": fields.Raw,
+ "tool_input": fields.String,
+ "created_at": TimestampField,
+ "observation": fields.String,
+ "files": fields.List(fields.String),
+ },
+)
+
+# Models that depend on simple_account_model
+feedback_model = console_ns.model(
+ "Feedback",
+ {
+ "rating": fields.String,
+ "content": fields.String,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_account": fields.Nested(simple_account_model, allow_null=True),
+ },
+)
+
+annotation_model = console_ns.model(
+ "Annotation",
+ {
+ "id": fields.String,
+ "question": fields.String,
+ "content": fields.String,
+ "account": fields.Nested(simple_account_model, allow_null=True),
+ "created_at": TimestampField,
+ },
+)
+
+annotation_hit_history_model = console_ns.model(
+ "AnnotationHitHistory",
+ {
+ "annotation_id": fields.String(attribute="id"),
+ "annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
+ "created_at": TimestampField,
+ },
+)
+
+# Message detail model that depends on multiple models
+message_detail_model = console_ns.model(
+ "MessageDetail",
+ {
+ "id": fields.String,
+ "conversation_id": fields.String,
+ "inputs": FilesContainedField,
+ "query": fields.String,
+ "message": fields.Raw,
+ "message_tokens": fields.Integer,
+ "answer": fields.String(attribute="re_sign_file_url_answer"),
+ "answer_tokens": fields.Integer,
+ "provider_response_latency": fields.Float,
+ "from_source": fields.String,
+ "from_end_user_id": fields.String,
+ "from_account_id": fields.String,
+ "feedbacks": fields.List(fields.Nested(feedback_model)),
+ "workflow_run_id": fields.String,
+ "annotation": fields.Nested(annotation_model, allow_null=True),
+ "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
+ "created_at": TimestampField,
+ "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
+ "message_files": fields.List(fields.Nested(message_file_model)),
+ "metadata": fields.Raw(attribute="message_metadata_dict"),
+ "status": fields.String,
+ "error": fields.String,
+ "parent_message_id": fields.String,
+ },
+)
+
+# Message infinite scroll pagination model
+message_infinite_scroll_pagination_model = console_ns.model(
+ "MessageInfiniteScrollPagination",
+ {
+ "limit": fields.Integer,
+ "has_more": fields.Boolean,
+ "data": fields.List(fields.Nested(message_detail_model)),
+ },
+)
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageListApi(Resource):
- message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_detail_fields)),
- }
-
- @api.doc("list_chat_messages")
- @api.doc(description="Get chat messages for a conversation with pagination")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser()
- .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
- .add_argument("first_id", type=str, location="args", help="First message ID for pagination")
- .add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
- )
- @api.response(200, "Success", message_infinite_scroll_pagination_fields)
- @api.response(404, "Conversation not found")
+ @console_ns.doc("list_chat_messages")
+ @console_ns.doc(description="Get chat messages for a conversation with pagination")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
+ @console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
+ @console_ns.response(404, "Conversation not found")
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
- @marshal_with(message_infinite_scroll_pagination_fields)
+ @marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
- parser = (
- reqparse.RequestParser()
- .add_argument("conversation_id", required=True, type=uuid_value, location="args")
- .add_argument("first_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- )
- args = parser.parse_args()
+ args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
- .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
+ .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
- if args["first_id"]:
+ if args.first_id:
first_message = (
db.session.query(Message)
- .where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
+ .where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
)
@@ -96,7 +260,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
- .limit(args["limit"])
+ .limit(args.limit)
.all()
)
else:
@@ -104,12 +268,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
- .limit(args["limit"])
+ .limit(args.limit)
.all()
)
# Initialize has_more based on whether we have a full page
- if len(history_messages) == args["limit"]:
+ if len(history_messages) == args.limit:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@@ -127,26 +291,18 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
- return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
+ return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
class MessageFeedbackApi(Resource):
- @api.doc("create_message_feedback")
- @api.doc(description="Create or update message feedback (like/dislike)")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "MessageFeedbackRequest",
- {
- "message_id": fields.String(required=True, description="Message ID"),
- "rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
- },
- )
- )
- @api.response(200, "Feedback updated successfully")
- @api.response(404, "Message not found")
- @api.response(403, "Insufficient permissions")
+ @console_ns.doc("create_message_feedback")
+ @console_ns.doc(description="Create or update message feedback (like/dislike)")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
+ @console_ns.response(200, "Feedback updated successfully")
+ @console_ns.response(404, "Message not found")
+ @console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@@ -154,14 +310,9 @@ class MessageFeedbackApi(Resource):
def post(self, app_model):
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("message_id", required=True, type=uuid_value, location="json")
- .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
- )
- args = parser.parse_args()
+ args = MessageFeedbackPayload.model_validate(console_ns.payload)
- message_id = str(args["message_id"])
+ message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@@ -170,18 +321,23 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
- if not args["rating"] and feedback:
+ if not args.rating and feedback:
db.session.delete(feedback)
- elif args["rating"] and feedback:
- feedback.rating = args["rating"]
- elif not args["rating"] and not feedback:
+ elif args.rating and feedback:
+ feedback.rating = args.rating
+ feedback.content = args.content
+ elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
+ rating_value = args.rating
+ if rating_value is None:
+ raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
- rating=args["rating"],
+ rating=rating_value,
+ content=args.content,
from_source="admin",
from_account_id=current_user.id,
)
@@ -194,13 +350,13 @@ class MessageFeedbackApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
- @api.doc("get_annotation_count")
- @api.doc(description="Get count of message annotations for the app")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(
+ @console_ns.doc("get_annotation_count")
+ @console_ns.doc(description="Get count of message annotations for the app")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(
200,
"Annotation count retrieved successfully",
- api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
+ console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
)
@get_app_model
@setup_required
@@ -214,15 +370,17 @@ class MessageAnnotationCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(Resource):
- @api.doc("get_message_suggested_questions")
- @api.doc(description="Get suggested questions for a message")
- @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
- @api.response(
+ @console_ns.doc("get_message_suggested_questions")
+ @console_ns.doc(description="Get suggested questions for a message")
+ @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
+ @console_ns.response(
200,
"Suggested questions retrieved successfully",
- api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
+ console_ns.model(
+ "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
+ ),
)
- @api.response(404, "Message or conversation not found")
+ @console_ns.response(404, "Message or conversation not found")
@setup_required
@login_required
@account_initialization_required
@@ -256,18 +414,58 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
-@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
-class MessageApi(Resource):
- @api.doc("get_message")
- @api.doc(description="Get message details by ID")
- @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
- @api.response(200, "Message retrieved successfully", message_detail_fields)
- @api.response(404, "Message not found")
+@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
+class MessageFeedbackExportApi(Resource):
+ @console_ns.doc("export_feedbacks")
+ @console_ns.doc(description="Export user feedback data for Google Sheets")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
+ @console_ns.response(200, "Feedback data exported successfully")
+ @console_ns.response(400, "Invalid parameters")
+ @console_ns.response(500, "Internal server error")
@get_app_model
@setup_required
@login_required
@account_initialization_required
- @marshal_with(message_detail_fields)
+ def get(self, app_model):
+ args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+
+ # Import the service function
+ from services.feedback_service import FeedbackService
+
+ try:
+ export_data = FeedbackService.export_feedbacks(
+ app_id=app_model.id,
+ from_source=args.from_source,
+ rating=args.rating,
+ has_comment=args.has_comment,
+ start_date=args.start_date,
+ end_date=args.end_date,
+ format_type=args.format,
+ )
+
+ return export_data
+
+ except ValueError as e:
+ logger.exception("Parameter validation error in feedback export")
+ return {"error": f"Parameter validation error: {str(e)}"}, 400
+ except Exception as e:
+ logger.exception("Error exporting feedback data")
+ raise InternalServerError(str(e))
+
+
+@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
+class MessageApi(Resource):
+ @console_ns.doc("get_message")
+ @console_ns.doc(description="Get message details by ID")
+ @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
+ @console_ns.response(200, "Message retrieved successfully", message_detail_model)
+ @console_ns.response(404, "Message not found")
+ @get_app_model
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @marshal_with(message_detail_model)
def get(self, app_model, message_id: str):
message_id = str(message_id)
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index 91e2cfd60e..a85e54fb51 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -4,7 +4,7 @@ from typing import cast
from flask import request
from flask_restx import Resource, fields
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.agent.entities import AgentToolEntity
@@ -20,11 +20,11 @@ from services.app_model_config_service import AppModelConfigService
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
- @api.doc("update_app_model_config")
- @api.doc(description="Update application model configuration")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
+ @console_ns.doc("update_app_model_config")
+ @console_ns.doc(description="Update application model configuration")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(
+ console_ns.model(
"ModelConfigRequest",
{
"provider": fields.String(description="Model provider"),
@@ -42,9 +42,9 @@ class ModelConfigResource(Resource):
},
)
)
- @api.response(200, "Model configuration updated successfully")
- @api.response(400, "Invalid configuration")
- @api.response(404, "App not found")
+ @console_ns.response(200, "Model configuration updated successfully")
+ @console_ns.response(400, "Invalid configuration")
+ @console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py
index 1d80314774..cbcf513162 100644
--- a/api/controllers/console/app/ops_trace.py
+++ b/api/controllers/console/app/ops_trace.py
@@ -1,12 +1,36 @@
-from flask_restx import Resource, fields, reqparse
+from typing import Any
+
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.ops_service import OpsService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class TraceProviderQuery(BaseModel):
+ tracing_provider: str = Field(..., description="Tracing provider name")
+
+
+class TraceConfigPayload(BaseModel):
+ tracing_provider: str = Field(..., description="Tracing provider name")
+ tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
+
+
+console_ns.schema_model(
+ TraceProviderQuery.__name__,
+ TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
@@ -14,64 +38,46 @@ class TraceAppConfigApi(Resource):
Manage trace app configurations
"""
- @api.doc("get_trace_app_config")
- @api.doc(description="Get tracing configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser().add_argument(
- "tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
- )
- )
- @api.response(
+ @console_ns.doc("get_trace_app_config")
+ @console_ns.doc(description="Get tracing configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
+ @console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
- @api.response(400, "Invalid request parameters")
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
- parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
- args = parser.parse_args()
+ args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
+ trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config:
return {"has_not_configured": True}
return trace_config
except Exception as e:
raise BadRequest(str(e))
- @api.doc("create_trace_app_config")
- @api.doc(description="Create a new tracing configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "TraceConfigCreateRequest",
- {
- "tracing_provider": fields.String(required=True, description="Tracing provider name"),
- "tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("create_trace_app_config")
+ @console_ns.doc(description="Create a new tracing configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
+ @console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
- @api.response(400, "Invalid request parameters or configuration already exists")
+ @console_ns.response(400, "Invalid request parameters or configuration already exists")
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
- parser = (
- reqparse.RequestParser()
- .add_argument("tracing_provider", type=str, required=True, location="json")
- .add_argument("tracing_config", type=dict, required=True, location="json")
- )
- args = parser.parse_args()
+ args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
- app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
+ app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@@ -81,35 +87,22 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
- @api.doc("update_trace_app_config")
- @api.doc(description="Update an existing tracing configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "TraceConfigUpdateRequest",
- {
- "tracing_provider": fields.String(required=True, description="Tracing provider name"),
- "tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
- },
- )
- )
- @api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
- @api.response(400, "Invalid request parameters or configuration not found")
+ @console_ns.doc("update_trace_app_config")
+ @console_ns.doc(description="Update an existing tracing configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
+ @console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
+ @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@login_required
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
- parser = (
- reqparse.RequestParser()
- .add_argument("tracing_provider", type=str, required=True, location="json")
- .add_argument("tracing_config", type=dict, required=True, location="json")
- )
- args = parser.parse_args()
+ args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
- app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
+ app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@@ -117,26 +110,21 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
- @api.doc("delete_trace_app_config")
- @api.doc(description="Delete an existing tracing configuration for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser().add_argument(
- "tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
- )
- )
- @api.response(204, "Tracing configuration deleted successfully")
- @api.response(400, "Invalid request parameters or configuration not found")
+ @console_ns.doc("delete_trace_app_config")
+ @console_ns.doc(description="Delete an existing tracing configuration for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
+ @console_ns.response(204, "Tracing configuration deleted successfully")
+ @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
- parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
- args = parser.parse_args()
+ args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
+ result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index b8edbf77c7..db218d8b81 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,8 +1,11 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from typing import Literal
+
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@@ -16,77 +19,61 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-def parse_app_site_args():
- parser = (
- reqparse.RequestParser()
- .add_argument("title", type=str, required=False, location="json")
- .add_argument("icon_type", type=str, required=False, location="json")
- .add_argument("icon", type=str, required=False, location="json")
- .add_argument("icon_background", type=str, required=False, location="json")
- .add_argument("description", type=str, required=False, location="json")
- .add_argument("default_language", type=supported_language, required=False, location="json")
- .add_argument("chat_color_theme", type=str, required=False, location="json")
- .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
- .add_argument("customize_domain", type=str, required=False, location="json")
- .add_argument("copyright", type=str, required=False, location="json")
- .add_argument("privacy_policy", type=str, required=False, location="json")
- .add_argument("custom_disclaimer", type=str, required=False, location="json")
- .add_argument(
- "customize_token_strategy",
- type=str,
- choices=["must", "allow", "not_allow"],
- required=False,
- location="json",
- )
- .add_argument("prompt_public", type=bool, required=False, location="json")
- .add_argument("show_workflow_steps", type=bool, required=False, location="json")
- .add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
- )
- return parser.parse_args()
+
+class AppSiteUpdatePayload(BaseModel):
+ title: str | None = Field(default=None)
+ icon_type: str | None = Field(default=None)
+ icon: str | None = Field(default=None)
+ icon_background: str | None = Field(default=None)
+ description: str | None = Field(default=None)
+ default_language: str | None = Field(default=None)
+ chat_color_theme: str | None = Field(default=None)
+ chat_color_theme_inverted: bool | None = Field(default=None)
+ customize_domain: str | None = Field(default=None)
+ copyright: str | None = Field(default=None)
+ privacy_policy: str | None = Field(default=None)
+ custom_disclaimer: str | None = Field(default=None)
+ customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
+ prompt_public: bool | None = Field(default=None)
+ show_workflow_steps: bool | None = Field(default=None)
+ use_icon_as_answer_icon: bool | None = Field(default=None)
+
+ @field_validator("default_language")
+ @classmethod
+ def validate_language(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return supported_language(value)
+
+
+console_ns.schema_model(
+ AppSiteUpdatePayload.__name__,
+ AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+# Register model for flask_restx to avoid dict type issues in Swagger
+app_site_model = console_ns.model("AppSite", app_site_fields)
@console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource):
- @api.doc("update_app_site")
- @api.doc(description="Update application site configuration")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "AppSiteRequest",
- {
- "title": fields.String(description="Site title"),
- "icon_type": fields.String(description="Icon type"),
- "icon": fields.String(description="Icon"),
- "icon_background": fields.String(description="Icon background color"),
- "description": fields.String(description="Site description"),
- "default_language": fields.String(description="Default language"),
- "chat_color_theme": fields.String(description="Chat color theme"),
- "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
- "customize_domain": fields.String(description="Custom domain"),
- "copyright": fields.String(description="Copyright text"),
- "privacy_policy": fields.String(description="Privacy policy"),
- "custom_disclaimer": fields.String(description="Custom disclaimer"),
- "customize_token_strategy": fields.String(
- enum=["must", "allow", "not_allow"], description="Token strategy"
- ),
- "prompt_public": fields.Boolean(description="Make prompt public"),
- "show_workflow_steps": fields.Boolean(description="Show workflow steps"),
- "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
- },
- )
- )
- @api.response(200, "Site configuration updated successfully", app_site_fields)
- @api.response(403, "Insufficient permissions")
- @api.response(404, "App not found")
+ @console_ns.doc("update_app_site")
+ @console_ns.doc(description="Update application site configuration")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
+ @console_ns.response(200, "Site configuration updated successfully", app_site_model)
+ @console_ns.response(403, "Insufficient permissions")
+ @console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model
- @marshal_with(app_site_fields)
+ @marshal_with(app_site_model)
def post(self, app_model):
- args = parse_app_site_args()
+ args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
@@ -110,7 +97,7 @@ class AppSite(Resource):
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
- value = args.get(attr_name)
+ value = getattr(args, attr_name)
if value is not None:
setattr(site, attr_name, value)
@@ -123,18 +110,18 @@ class AppSite(Resource):
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
class AppSiteAccessTokenReset(Resource):
- @api.doc("reset_app_site_access_token")
- @api.doc(description="Reset access token for application site")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Access token reset successfully", app_site_fields)
- @api.response(403, "Insufficient permissions (admin/owner required)")
- @api.response(404, "App or site not found")
+ @console_ns.doc("reset_app_site_access_token")
+ @console_ns.doc(description="Reset access token for application site")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Access token reset successfully", app_site_model)
+ @console_ns.response(403, "Insufficient permissions (admin/owner required)")
+ @console_ns.response(404, "App or site not found")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
- @marshal_with(app_site_fields)
+ @marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index b4bd05e891..ffa28b1c95 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -1,31 +1,48 @@
from decimal import Decimal
import sqlalchemy as sa
-from flask import abort, jsonify
-from flask_restx import Resource, fields, reqparse
+from flask import abort, jsonify, request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
-from libs.helper import DatetimeString, convert_datetime_to_date
+from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class StatisticTimeRangeQuery(BaseModel):
+ start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
+ end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
+
+ @field_validator("start", "end", mode="before")
+ @classmethod
+ def empty_string_to_none(cls, value: str | None) -> str | None:
+ if value == "":
+ return None
+ return value
+
+
+console_ns.schema_model(
+ StatisticTimeRangeQuery.__name__,
+ StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource):
- @api.doc("get_daily_message_statistics")
- @api.doc(description="Get daily message statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.parser()
- .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
- .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
- )
- @api.response(
+ @console_ns.doc("get_daily_message_statistics")
+ @console_ns.doc(description="Get daily message statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Daily message statistics retrieved successfully",
fields.List(fields.Raw(description="Daily message count data")),
@@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- )
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@@ -57,7 +69,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -81,20 +93,13 @@ WHERE
return jsonify({"data": response_data})
-parser = (
- reqparse.RequestParser()
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
-)
-
-
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
- @api.doc("get_daily_conversation_statistics")
- @api.doc(description="Get daily conversation statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_daily_conversation_statistics")
+ @console_ns.doc(description="Get daily conversation statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Daily conversation statistics retrieved successfully",
fields.List(fields.Raw(description="Daily conversation count data")),
@@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@@ -121,7 +126,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -146,11 +151,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
class DailyTerminalsStatistic(Resource):
- @api.doc("get_daily_terminals_statistics")
- @api.doc(description="Get daily terminal/end-user statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_daily_terminals_statistics")
+ @console_ns.doc(description="Get daily terminal/end-user statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Daily terminal statistics retrieved successfully",
fields.List(fields.Raw(description="Daily terminal count data")),
@@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@@ -177,7 +182,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -203,11 +208,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
class DailyTokenCostStatistic(Resource):
- @api.doc("get_daily_token_cost_statistics")
- @api.doc(description="Get daily token cost statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_daily_token_cost_statistics")
+ @console_ns.doc(description="Get daily token cost statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Daily token cost statistics retrieved successfully",
fields.List(fields.Raw(description="Daily token cost data")),
@@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@@ -235,7 +240,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -263,11 +268,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
class AverageSessionInteractionStatistic(Resource):
- @api.doc("get_average_session_interaction_statistics")
- @api.doc(description="Get average session interaction statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_average_session_interaction_statistics")
+ @console_ns.doc(description="Get average session interaction statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Average session interaction statistics retrieved successfully",
fields.List(fields.Raw(description="Average session interaction data")),
@@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
@@ -302,7 +307,7 @@ FROM
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -339,11 +344,11 @@ ORDER BY
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
class UserSatisfactionRateStatistic(Resource):
- @api.doc("get_user_satisfaction_rate_statistics")
- @api.doc(description="Get user satisfaction rate statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_user_satisfaction_rate_statistics")
+ @console_ns.doc(description="Get user satisfaction rate statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"User satisfaction rate statistics retrieved successfully",
fields.List(fields.Raw(description="User satisfaction rate data")),
@@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
@@ -374,7 +379,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -405,11 +410,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
class AverageResponseTimeStatistic(Resource):
- @api.doc("get_average_response_time_statistics")
- @api.doc(description="Get average response time statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_average_response_time_statistics")
+ @console_ns.doc(description="Get average response time statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Average response time statistics retrieved successfully",
fields.List(fields.Raw(description="Average response time data")),
@@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@@ -436,7 +441,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -462,11 +467,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
class TokensPerSecondStatistic(Resource):
- @api.doc("get_tokens_per_second_statistics")
- @api.doc(description="Get tokens per second statistics for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("get_tokens_per_second_statistics")
+ @console_ns.doc(description="Get tokens per second statistics for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
+ @console_ns.response(
200,
"Tokens per second statistics retrieved successfully",
fields.List(fields.Raw(description="Tokens per second data")),
@@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
- args = parser.parse_args()
+ args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@@ -495,7 +500,7 @@ WHERE
assert account.timezone is not None
try:
- start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+ start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 2f6808f11d..b4f2ef0ba8 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,15 +1,16 @@
import json
import logging
from collections.abc import Sequence
-from typing import cast
+from typing import Any
from flask import abort, request
-from flask_restx import Resource, fields, inputs, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -32,6 +33,7 @@ from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
+from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@@ -48,6 +50,161 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register in dependency order: base models first, then dependent models
+
+# Base models
+simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
+
+from fields.workflow_fields import pipeline_variable_fields, serialize_value_type
+
+conversation_variable_model = console_ns.model(
+ "ConversationVariable",
+ {
+ "id": fields.String,
+ "name": fields.String,
+ "value_type": fields.String(attribute=serialize_value_type),
+ "value": fields.Raw,
+ "description": fields.String,
+ },
+)
+
+pipeline_variable_model = console_ns.model("PipelineVariable", pipeline_variable_fields)
+
+# Workflow model with nested dependencies
+workflow_fields_copy = workflow_fields.copy()
+workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
+workflow_fields_copy["updated_by"] = fields.Nested(
+ simple_account_model, attribute="updated_by_account", allow_null=True
+)
+workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
+workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
+workflow_model = console_ns.model("Workflow", workflow_fields_copy)
+
+# Workflow pagination model
+workflow_pagination_fields_copy = workflow_pagination_fields.copy()
+workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
+workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
+
+# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
+# Otherwise register it here
+from fields.end_user_fields import simple_end_user_fields
+
+simple_end_user_model = None
+try:
+ simple_end_user_model = console_ns.models.get("SimpleEndUser")
+except AttributeError:
+ pass
+if simple_end_user_model is None:
+ simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
+
+workflow_run_node_execution_model = None
+try:
+ workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
+except AttributeError:
+ pass
+if workflow_run_node_execution_model is None:
+ workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
+
+
+class SyncDraftWorkflowPayload(BaseModel):
+ graph: dict[str, Any]
+ features: dict[str, Any]
+ hash: str | None = None
+ environment_variables: list[dict[str, Any]] = Field(default_factory=list)
+ conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
+
+
+class BaseWorkflowRunPayload(BaseModel):
+ files: list[dict[str, Any]] | None = None
+
+
+class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
+ inputs: dict[str, Any] | None = None
+ query: str = ""
+ conversation_id: str | None = None
+ parent_message_id: str | None = None
+
+ @field_validator("conversation_id", "parent_message_id")
+ @classmethod
+ def validate_uuid(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class IterationNodeRunPayload(BaseModel):
+ inputs: dict[str, Any] | None = None
+
+
+class LoopNodeRunPayload(BaseModel):
+ inputs: dict[str, Any] | None = None
+
+
+class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
+ inputs: dict[str, Any]
+
+
+class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
+ inputs: dict[str, Any]
+ query: str = ""
+
+
+class PublishWorkflowPayload(BaseModel):
+ marked_name: str | None = Field(default=None, max_length=20)
+ marked_comment: str | None = Field(default=None, max_length=100)
+
+
+class DefaultBlockConfigQuery(BaseModel):
+ q: str | None = None
+
+
+class ConvertToWorkflowPayload(BaseModel):
+ name: str | None = None
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+
+
+class WorkflowListQuery(BaseModel):
+ page: int = Field(default=1, ge=1, le=99999)
+ limit: int = Field(default=10, ge=1, le=100)
+ user_id: str | None = None
+ named_only: bool = False
+
+
+class WorkflowUpdatePayload(BaseModel):
+ marked_name: str | None = Field(default=None, max_length=20)
+ marked_comment: str | None = Field(default=None, max_length=100)
+
+
+class DraftWorkflowTriggerRunPayload(BaseModel):
+ node_id: str
+
+
+class DraftWorkflowTriggerRunAllPayload(BaseModel):
+ node_ids: list[str]
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(SyncDraftWorkflowPayload)
+reg(AdvancedChatWorkflowRunPayload)
+reg(IterationNodeRunPayload)
+reg(LoopNodeRunPayload)
+reg(DraftWorkflowRunPayload)
+reg(DraftWorkflowNodeRunPayload)
+reg(PublishWorkflowPayload)
+reg(DefaultBlockConfigQuery)
+reg(ConvertToWorkflowPayload)
+reg(WorkflowListQuery)
+reg(WorkflowUpdatePayload)
+reg(DraftWorkflowTriggerRunPayload)
+reg(DraftWorkflowTriggerRunAllPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@@ -70,16 +227,16 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
@console_ns.route("/apps/<uuid:app_id>/workflows/draft")
class DraftWorkflowApi(Resource):
- @api.doc("get_draft_workflow")
- @api.doc(description="Get draft workflow for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Draft workflow retrieved successfully", workflow_fields)
- @api.response(404, "Draft workflow not found")
+ @console_ns.doc("get_draft_workflow")
+ @console_ns.doc(description="Get draft workflow for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
+ @console_ns.response(404, "Draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
@edit_permission_required
def get(self, app_model: App):
"""
@@ -99,24 +256,13 @@ class DraftWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @api.doc("sync_draft_workflow")
- @api.doc(description="Sync draft workflow configuration")
- @api.expect(
- api.model(
- "SyncDraftWorkflowRequest",
- {
- "graph": fields.Raw(required=True, description="Workflow graph configuration"),
- "features": fields.Raw(required=True, description="Workflow features configuration"),
- "hash": fields.String(description="Workflow hash for validation"),
- "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
- "conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("sync_draft_workflow")
+ @console_ns.doc(description="Sync draft workflow configuration")
+ @console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
+ @console_ns.response(
200,
"Draft workflow synced successfully",
- api.model(
+ console_ns.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
@@ -125,8 +271,8 @@ class DraftWorkflowApi(Resource):
},
),
)
- @api.response(400, "Invalid workflow configuration")
- @api.response(403, "Permission denied")
+ @console_ns.response(400, "Invalid workflow configuration")
+ @console_ns.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App):
"""
@@ -136,36 +282,23 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
+ payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
- parser = (
- reqparse.RequestParser()
- .add_argument("graph", type=dict, required=True, nullable=False, location="json")
- .add_argument("features", type=dict, required=True, nullable=False, location="json")
- .add_argument("hash", type=str, required=False, location="json")
- .add_argument("environment_variables", type=list, required=True, location="json")
- .add_argument("conversation_variables", type=list, required=False, location="json")
- )
- args = parser.parse_args()
+ payload_data = request.get_json(silent=True)
+ if not isinstance(payload_data, dict):
+ return {"message": "Invalid JSON data"}, 400
elif "text/plain" in content_type:
try:
- data = json.loads(request.data.decode("utf-8"))
- if "graph" not in data or "features" not in data:
- raise ValueError("graph or features not found in data")
-
- if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
- raise ValueError("graph or features is not a dict")
-
- args = {
- "graph": data.get("graph"),
- "features": data.get("features"),
- "hash": data.get("hash"),
- "environment_variables": data.get("environment_variables"),
- "conversation_variables": data.get("conversation_variables"),
- }
+ payload_data = json.loads(request.data.decode("utf-8"))
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
+ if not isinstance(payload_data, dict):
+ return {"message": "Invalid JSON data"}, 400
else:
abort(415)
+
+ args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
+ args = args_model.model_dump()
workflow_service = WorkflowService()
try:
@@ -198,23 +331,13 @@ class DraftWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
class AdvancedChatDraftWorkflowRunApi(Resource):
- @api.doc("run_advanced_chat_draft_workflow")
- @api.doc(description="Run draft workflow for advanced chat application")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "AdvancedChatWorkflowRunRequest",
- {
- "query": fields.String(required=True, description="User query"),
- "inputs": fields.Raw(description="Input variables"),
- "files": fields.List(fields.Raw, description="File uploads"),
- "conversation_id": fields.String(description="Conversation ID"),
- },
- )
- )
- @api.response(200, "Workflow run started successfully")
- @api.response(400, "Invalid request parameters")
- @api.response(403, "Permission denied")
+ @console_ns.doc("run_advanced_chat_draft_workflow")
+ @console_ns.doc(description="Run draft workflow for advanced chat application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
+ @console_ns.response(200, "Workflow run started successfully")
+ @console_ns.response(400, "Invalid request parameters")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -226,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
"""
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, location="json")
- .add_argument("query", type=str, required=True, location="json", default="")
- .add_argument("files", type=list, location="json")
- .add_argument("conversation_id", type=uuid_value, location="json")
- .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
- )
-
- args = parser.parse_args()
+ args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
+ args = args_model.model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@@ -262,21 +377,13 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run")
class AdvancedChatDraftRunIterationNodeApi(Resource):
- @api.doc("run_advanced_chat_draft_iteration_node")
- @api.doc(description="Run draft workflow iteration node for advanced chat")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.expect(
- api.model(
- "IterationNodeRunRequest",
- {
- "task_id": fields.String(required=True, description="Task ID"),
- "inputs": fields.Raw(description="Input variables"),
- },
- )
- )
- @api.response(200, "Iteration node run started successfully")
- @api.response(403, "Permission denied")
- @api.response(404, "Node not found")
+ @console_ns.doc("run_advanced_chat_draft_iteration_node")
+ @console_ns.doc(description="Run draft workflow iteration node for advanced chat")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
+ @console_ns.response(200, "Iteration node run started successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(404, "Node not found")
@setup_required
@login_required
@account_initialization_required
@@ -287,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
- args = parser.parse_args()
+ args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_iteration(
@@ -309,21 +415,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class WorkflowDraftRunIterationNodeApi(Resource):
- @api.doc("run_workflow_draft_iteration_node")
- @api.doc(description="Run draft workflow iteration node")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.expect(
- api.model(
- "WorkflowIterationNodeRunRequest",
- {
- "task_id": fields.String(required=True, description="Task ID"),
- "inputs": fields.Raw(description="Input variables"),
- },
- )
- )
- @api.response(200, "Workflow iteration node run started successfully")
- @api.response(403, "Permission denied")
- @api.response(404, "Node not found")
+ @console_ns.doc("run_workflow_draft_iteration_node")
+ @console_ns.doc(description="Run draft workflow iteration node")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
+ @console_ns.response(200, "Workflow iteration node run started successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(404, "Node not found")
@setup_required
@login_required
@account_initialization_required
@@ -334,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
- args = parser.parse_args()
+ args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_iteration(
@@ -356,21 +453,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run")
class AdvancedChatDraftRunLoopNodeApi(Resource):
- @api.doc("run_advanced_chat_draft_loop_node")
- @api.doc(description="Run draft workflow loop node for advanced chat")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.expect(
- api.model(
- "LoopNodeRunRequest",
- {
- "task_id": fields.String(required=True, description="Task ID"),
- "inputs": fields.Raw(description="Input variables"),
- },
- )
- )
- @api.response(200, "Loop node run started successfully")
- @api.response(403, "Permission denied")
- @api.response(404, "Node not found")
+ @console_ns.doc("run_advanced_chat_draft_loop_node")
+ @console_ns.doc(description="Run draft workflow loop node for advanced chat")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
+ @console_ns.response(200, "Loop node run started successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(404, "Node not found")
@setup_required
@login_required
@account_initialization_required
@@ -381,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
- args = parser.parse_args()
+ args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_loop(
@@ -403,21 +491,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class WorkflowDraftRunLoopNodeApi(Resource):
- @api.doc("run_workflow_draft_loop_node")
- @api.doc(description="Run draft workflow loop node")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.expect(
- api.model(
- "WorkflowLoopNodeRunRequest",
- {
- "task_id": fields.String(required=True, description="Task ID"),
- "inputs": fields.Raw(description="Input variables"),
- },
- )
- )
- @api.response(200, "Workflow loop node run started successfully")
- @api.response(403, "Permission denied")
- @api.response(404, "Node not found")
+ @console_ns.doc("run_workflow_draft_loop_node")
+ @console_ns.doc(description="Run draft workflow loop node")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
+ @console_ns.response(200, "Workflow loop node run started successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(404, "Node not found")
@setup_required
@login_required
@account_initialization_required
@@ -428,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
- args = parser.parse_args()
+ args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = AppGenerateService.generate_single_loop(
@@ -450,20 +529,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
class DraftWorkflowRunApi(Resource):
- @api.doc("run_draft_workflow")
- @api.doc(description="Run draft workflow")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "DraftWorkflowRunRequest",
- {
- "inputs": fields.Raw(required=True, description="Input variables"),
- "files": fields.List(fields.Raw, description="File uploads"),
- },
- )
- )
- @api.response(200, "Draft workflow run started successfully")
- @api.response(403, "Permission denied")
+ @console_ns.doc("run_draft_workflow")
+ @console_ns.doc(description="Run draft workflow")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
+ @console_ns.response(200, "Draft workflow run started successfully")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -474,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("files", type=list, required=False, location="json")
- )
- args = parser.parse_args()
+ args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@@ -501,12 +567,12 @@ class DraftWorkflowRunApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
class WorkflowTaskStopApi(Resource):
- @api.doc("stop_workflow_task")
- @api.doc(description="Stop running workflow task")
- @api.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
- @api.response(200, "Task stopped successfully")
- @api.response(404, "Task not found")
- @api.response(403, "Permission denied")
+ @console_ns.doc("stop_workflow_task")
+ @console_ns.doc(description="Stop running workflow task")
+ @console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
+ @console_ns.response(200, "Task stopped successfully")
+ @console_ns.response(404, "Task not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -528,40 +594,28 @@ class WorkflowTaskStopApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
class DraftWorkflowNodeRunApi(Resource):
- @api.doc("run_draft_workflow_node")
- @api.doc(description="Run draft workflow node")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.expect(
- api.model(
- "DraftWorkflowNodeRunRequest",
- {
- "inputs": fields.Raw(description="Input variables"),
- },
- )
- )
- @api.response(200, "Node run started successfully", workflow_run_node_execution_fields)
- @api.response(403, "Permission denied")
- @api.response(404, "Node not found")
+ @console_ns.doc("run_draft_workflow_node")
+ @console_ns.doc(description="Run draft workflow node")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
+ @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(404, "Node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow node
"""
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("query", type=str, required=False, location="json", default="")
- .add_argument("files", type=list, location="json", default=[])
- )
- args = parser.parse_args()
+ args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
+ args = args_model.model_dump(exclude_none=True)
- user_inputs = args.get("inputs")
+ user_inputs = args_model.inputs
if user_inputs is None:
raise ValueError("missing inputs")
@@ -586,25 +640,18 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
-parser_publish = (
- reqparse.RequestParser()
- .add_argument("marked_name", type=str, required=False, default="", location="json")
- .add_argument("marked_comment", type=str, required=False, default="", location="json")
-)
-
-
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
- @api.doc("get_published_workflow")
- @api.doc(description="Get published workflow for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Published workflow retrieved successfully", workflow_fields)
- @api.response(404, "Published workflow not found")
+ @console_ns.doc("get_published_workflow")
+ @console_ns.doc(description="Get published workflow for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
+ @console_ns.response(404, "Published workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
@edit_permission_required
def get(self, app_model: App):
"""
@@ -617,7 +664,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
- @api.expect(parser_publish)
+ @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -629,13 +676,7 @@ class PublishedWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
- args = parser_publish.parse_args()
-
- # Validate name and comment length
- if args.marked_name and len(args.marked_name) > 20:
- raise ValueError("Marked name cannot exceed 20 characters")
- if args.marked_comment and len(args.marked_comment) > 100:
- raise ValueError("Marked comment cannot exceed 100 characters")
+ args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with Session(db.engine) as session:
@@ -666,10 +707,10 @@ class PublishedWorkflowApi(Resource):
@console_ns.route("/apps//workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource):
- @api.doc("get_default_block_configs")
- @api.doc(description="Get default block configurations for workflow")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Default block configurations retrieved successfully")
+ @console_ns.doc("get_default_block_configs")
+ @console_ns.doc(description="Get default block configurations for workflow")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Default block configurations retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -684,17 +725,14 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
-parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
-
-
@console_ns.route("/apps//workflows/default-workflow-block-configs/")
class DefaultBlockConfigApi(Resource):
- @api.doc("get_default_block_config")
- @api.doc(description="Get default block configuration by type")
- @api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
- @api.response(200, "Default block configuration retrieved successfully")
- @api.response(404, "Block type not found")
- @api.expect(parser_block)
+ @console_ns.doc("get_default_block_config")
+ @console_ns.doc(description="Get default block configuration by type")
+ @console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
+ @console_ns.response(200, "Default block configuration retrieved successfully")
+ @console_ns.response(404, "Block type not found")
+ @console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -704,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
- args = parser_block.parse_args()
-
- q = args.get("q")
+ args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
filters = None
- if q:
+ if args.q:
try:
- filters = json.loads(args.get("q", ""))
+ filters = json.loads(args.q)
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@@ -720,24 +756,15 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
-parser_convert = (
- reqparse.RequestParser()
- .add_argument("name", type=str, required=False, nullable=True, location="json")
- .add_argument("icon_type", type=str, required=False, nullable=True, location="json")
- .add_argument("icon", type=str, required=False, nullable=True, location="json")
- .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route("/apps//convert-to-workflow")
class ConvertToWorkflowApi(Resource):
- @api.expect(parser_convert)
- @api.doc("convert_to_workflow")
- @api.doc(description="Convert application to workflow mode")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Application converted to workflow successfully")
- @api.response(400, "Application cannot be converted")
- @api.response(403, "Permission denied")
+ @console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
+ @console_ns.doc("convert_to_workflow")
+ @console_ns.doc(description="Convert application to workflow mode")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Application converted to workflow successfully")
+ @console_ns.response(400, "Application cannot be converted")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -751,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
- if request.data:
- args = parser_convert.parse_args()
- else:
- args = {}
+ payload = console_ns.payload or {}
+ args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
# convert to workflow mode
workflow_service = WorkflowService()
@@ -766,27 +791,18 @@ class ConvertToWorkflowApi(Resource):
}
-parser_workflows = (
- reqparse.RequestParser()
- .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
- .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
- .add_argument("user_id", type=str, required=False, location="args")
- .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
-)
-
-
@console_ns.route("/apps//workflows")
class PublishedAllWorkflowApi(Resource):
- @api.expect(parser_workflows)
- @api.doc("get_all_published_workflows")
- @api.doc(description="Get all published workflows for an application")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Published workflows retrieved successfully", workflow_pagination_fields)
+ @console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
+ @console_ns.doc("get_all_published_workflows")
+ @console_ns.doc(description="Get all published workflows for an application")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_pagination_fields)
+ @marshal_with(workflow_pagination_model)
@edit_permission_required
def get(self, app_model: App):
"""
@@ -794,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
- args = parser_workflows.parse_args()
- page = args["page"]
- limit = args["limit"]
- user_id = args.get("user_id")
- named_only = args.get("named_only", False)
+ args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ page = args.page
+ limit = args.limit
+ user_id = args.user_id
+ named_only = args.named_only
if user_id:
if user_id != current_user.id:
raise Forbidden()
- user_id = cast(str, user_id)
workflow_service = WorkflowService()
with Session(db.engine) as session:
@@ -826,51 +841,32 @@ class PublishedAllWorkflowApi(Resource):
@console_ns.route("/apps//workflows/")
class WorkflowByIdApi(Resource):
- @api.doc("update_workflow_by_id")
- @api.doc(description="Update workflow by ID")
- @api.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
- @api.expect(
- api.model(
- "UpdateWorkflowRequest",
- {
- "environment_variables": fields.List(fields.Raw, description="Environment variables"),
- "conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
- },
- )
- )
- @api.response(200, "Workflow updated successfully", workflow_fields)
- @api.response(404, "Workflow not found")
- @api.response(403, "Permission denied")
+ @console_ns.doc("update_workflow_by_id")
+ @console_ns.doc(description="Update workflow by ID")
+ @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
+ @console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
+ @console_ns.response(200, "Workflow updated successfully", workflow_model)
+ @console_ns.response(404, "Workflow not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("marked_name", type=str, required=False, location="json")
- .add_argument("marked_comment", type=str, required=False, location="json")
- )
- args = parser.parse_args()
-
- # Validate name and comment length
- if args.marked_name and len(args.marked_name) > 20:
- raise ValueError("Marked name cannot exceed 20 characters")
- if args.marked_comment and len(args.marked_comment) > 100:
- raise ValueError("Marked comment cannot exceed 100 characters")
+ args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
# Prepare update data
update_data = {}
- if args.get("marked_name") is not None:
- update_data["marked_name"] = args["marked_name"]
- if args.get("marked_comment") is not None:
- update_data["marked_comment"] = args["marked_comment"]
+ if args.marked_name is not None:
+ update_data["marked_name"] = args.marked_name
+ if args.marked_comment is not None:
+ update_data["marked_comment"] = args.marked_comment
if not update_data:
return {"message": "No valid fields to update"}, 400
@@ -926,17 +922,17 @@ class WorkflowByIdApi(Resource):
@console_ns.route("/apps//workflows/draft/nodes//last-run")
class DraftWorkflowNodeLastRunApi(Resource):
- @api.doc("get_draft_workflow_node_last_run")
- @api.doc(description="Get last run result for draft workflow node")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields)
- @api.response(404, "Node last run not found")
- @api.response(403, "Permission denied")
+ @console_ns.doc("get_draft_workflow_node_last_run")
+ @console_ns.doc(description="Get last run result for draft workflow node")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
+ @console_ns.response(404, "Node last run not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def get(self, app_model: App, node_id: str):
srv = WorkflowService()
workflow = srv.get_draft_workflow(app_model)
@@ -959,20 +955,20 @@ class DraftWorkflowTriggerRunApi(Resource):
Path: /apps/<app_id>/workflows/draft/trigger/run
"""
- @api.doc("poll_draft_workflow_trigger_run")
- @api.doc(description="Poll for trigger events and execute full workflow when event arrives")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
+ @console_ns.doc("poll_draft_workflow_trigger_run")
+ @console_ns.doc(description="Poll for trigger events and execute full workflow when event arrives")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(
+ console_ns.model(
"DraftWorkflowTriggerRunRequest",
{
"node_id": fields.String(required=True, description="Node ID"),
},
)
)
- @api.response(200, "Trigger event received and workflow executed successfully")
- @api.response(403, "Permission denied")
- @api.response(500, "Internal server error")
+ @console_ns.response(200, "Trigger event received and workflow executed successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@@ -983,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument(
- "node_id", type=str, required=True, location="json", nullable=False
- )
- args = parser.parse_args()
- node_id = args["node_id"]
+ args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
+ node_id = args.node_id
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
@@ -1033,12 +1026,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
Path: /apps/<app_id>/workflows/draft/nodes/<node_id>/trigger/run
"""
- @api.doc("poll_draft_workflow_trigger_node")
- @api.doc(description="Poll for trigger events and execute single node when event arrives")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.response(200, "Trigger event received and node executed successfully")
- @api.response(403, "Permission denied")
- @api.response(500, "Internal server error")
+ @console_ns.doc("poll_draft_workflow_trigger_node")
+ @console_ns.doc(description="Poll for trigger events and execute single node when event arrives")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.response(200, "Trigger event received and node executed successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@@ -1112,20 +1105,13 @@ class DraftWorkflowTriggerRunAllApi(Resource):
Path: /apps/<app_id>/workflows/draft/trigger/run-all
"""
- @api.doc("draft_workflow_trigger_run_all")
- @api.doc(description="Full workflow debug when the start node is a trigger")
- @api.doc(params={"app_id": "Application ID"})
- @api.expect(
- api.model(
- "DraftWorkflowTriggerRunAllRequest",
- {
- "node_ids": fields.List(fields.String, required=True, description="Node IDs"),
- },
- )
- )
- @api.response(200, "Workflow executed successfully")
- @api.response(403, "Permission denied")
- @api.response(500, "Internal server error")
+ @console_ns.doc("draft_workflow_trigger_run_all")
+ @console_ns.doc(description="Full workflow debug when the start node is a trigger")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
+ @console_ns.response(200, "Workflow executed successfully")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@@ -1137,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"""
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument(
- "node_ids", type=list, required=True, location="json", nullable=False
- )
- args = parser.parse_args()
- node_ids = args["node_ids"]
+ args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
+ node_ids = args.node_ids
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index d7ecc7c91b..fa67fb8154 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -1,86 +1,85 @@
+from datetime import datetime
+
from dateutil.parser import isoparse
-from flask_restx import Resource, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
+from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowAppLogQuery(BaseModel):
+ keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
+ status: WorkflowExecutionStatus | None = Field(
+ default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
+ )
+ created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
+ created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
+ created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
+ created_by_account: str | None = Field(default=None, description="Filter by account")
+ detail: bool = Field(default=False, description="Whether to return detailed logs")
+ page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
+ limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
+
+ @field_validator("created_at__before", "created_at__after", mode="before")
+ @classmethod
+ def parse_datetime(cls, value: str | None) -> datetime | None:
+ if value in (None, ""):
+ return None
+ return isoparse(value) # type: ignore
+
+ @field_validator("detail", mode="before")
+ @classmethod
+ def parse_bool(cls, value: bool | str | None) -> bool:
+ if isinstance(value, bool):
+ return value
+ if value is None:
+ return False
+ lowered = value.lower()
+ if lowered in {"1", "true", "yes", "on"}:
+ return True
+ if lowered in {"0", "false", "no", "off"}:
+ return False
+ raise ValueError("Invalid boolean value for detail")
+
+
+console_ns.schema_model(
+ WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+# Register model for flask_restx to avoid dict type issues in Swagger
+workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
+
@console_ns.route("/apps//workflow-app-logs")
class WorkflowAppLogApi(Resource):
- @api.doc("get_workflow_app_logs")
- @api.doc(description="Get workflow application execution logs")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(
- params={
- "keyword": "Search keyword for filtering logs",
- "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
- "created_at__before": "Filter logs created before this timestamp",
- "created_at__after": "Filter logs created after this timestamp",
- "created_by_end_user_session_id": "Filter by end user session ID",
- "created_by_account": "Filter by account",
- "detail": "Whether to return detailed logs",
- "page": "Page number (1-99999)",
- "limit": "Number of items per page (1-100)",
- }
- )
- @api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
+ @console_ns.doc("get_workflow_app_logs")
+ @console_ns.doc(description="Get workflow application execution logs")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
+ @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
- @marshal_with(workflow_app_log_pagination_fields)
+ @marshal_with(workflow_app_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow app logs
"""
- parser = (
- reqparse.RequestParser()
- .add_argument("keyword", type=str, location="args")
- .add_argument(
- "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
- )
- .add_argument(
- "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
- )
- .add_argument(
- "created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
- )
- .add_argument(
- "created_by_end_user_session_id",
- type=str,
- location="args",
- required=False,
- default=None,
- )
- .add_argument(
- "created_by_account",
- type=str,
- location="args",
- required=False,
- default=None,
- )
- .add_argument("detail", type=bool, location="args", required=False, default=False)
- .add_argument("page", type=int_range(1, 99999), default=1, location="args")
- .add_argument("limit", type=int_range(1, 100), default=20, location="args")
- )
- args = parser.parse_args()
-
- args.status = WorkflowExecutionStatus(args.status) if args.status else None
- if args.created_at__before:
- args.created_at__before = isoparse(args.created_at__before)
-
- if args.created_at__after:
- args.created_at__after = isoparse(args.created_at__after)
+ args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index ca97d8520c..3382b65acc 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -1,13 +1,14 @@
import logging
from collections.abc import Callable
from functools import wraps
-from typing import NoReturn, ParamSpec, TypeVar
+from typing import Any, NoReturn, ParamSpec, TypeVar
-from flask import Response
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask import Response, request
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
@@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowDraftVariableListQuery(BaseModel):
+ page: int = Field(default=1, ge=1, le=100_000, description="Page number")
+ limit: int = Field(default=20, ge=1, le=100, description="Items per page")
+
+
+class WorkflowDraftVariableUpdatePayload(BaseModel):
+ name: str | None = Field(default=None, description="Variable name")
+ value: Any | None = Field(default=None, description="Variable value")
+
+
+console_ns.schema_model(
+ WorkflowDraftVariableListQuery.__name__,
+ WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+ WorkflowDraftVariableUpdatePayload.__name__,
+ WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
def _convert_values_to_json_serializable_object(value: Segment):
@@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value)
-def _create_pagination_parser():
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "page",
- type=inputs.int_range(1, 100_000),
- required=False,
- default=1,
- location="args",
- help="the page of data requested",
- )
- .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
- )
- return parser
-
-
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
@@ -141,6 +147,37 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
+# Register models for flask_restx to avoid dict type issues in Swagger
+workflow_draft_variable_without_value_model = console_ns.model(
+ "WorkflowDraftVariableWithoutValue", _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS
+)
+
+workflow_draft_variable_model = console_ns.model("WorkflowDraftVariable", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
+
+workflow_draft_env_variable_model = console_ns.model("WorkflowDraftEnvVariable", _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)
+
+workflow_draft_env_variable_list_fields_copy = _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS.copy()
+workflow_draft_env_variable_list_fields_copy["items"] = fields.List(fields.Nested(workflow_draft_env_variable_model))
+workflow_draft_env_variable_list_model = console_ns.model(
+ "WorkflowDraftEnvVariableList", workflow_draft_env_variable_list_fields_copy
+)
+
+workflow_draft_variable_list_without_value_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS.copy()
+workflow_draft_variable_list_without_value_fields_copy["items"] = fields.List(
+ fields.Nested(workflow_draft_variable_without_value_model), attribute=_get_items
+)
+workflow_draft_variable_list_without_value_model = console_ns.model(
+ "WorkflowDraftVariableListWithoutValue", workflow_draft_variable_list_without_value_fields_copy
+)
+
+workflow_draft_variable_list_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS.copy()
+workflow_draft_variable_list_fields_copy["items"] = fields.List(
+ fields.Nested(workflow_draft_variable_model), attribute=_get_items
+)
+workflow_draft_variable_list_model = console_ns.model(
+ "WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
+)
+
P = ParamSpec("P")
R = TypeVar("R")
@@ -170,20 +207,21 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps//workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource):
- @api.expect(_create_pagination_parser())
- @api.doc("get_workflow_variables")
- @api.doc(description="Get draft workflow variables")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
- @api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
+ @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
+ @console_ns.doc("get_workflow_variables")
+ @console_ns.doc(description="Get draft workflow variables")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
+ @console_ns.response(
+ 200, "Workflow variables retrieved successfully", workflow_draft_variable_list_without_value_model
+ )
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
+ @marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, app_model: App):
"""
Get draft workflow
"""
- parser = _create_pagination_parser()
- args = parser.parse_args()
+ args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# fetch draft workflow by app_model
workflow_service = WorkflowService()
@@ -204,9 +242,9 @@ class WorkflowVariableCollectionApi(Resource):
return workflow_vars
- @api.doc("delete_workflow_variables")
- @api.doc(description="Delete all draft workflow variables")
- @api.response(204, "Workflow variables deleted successfully")
+ @console_ns.doc("delete_workflow_variables")
+ @console_ns.doc(description="Delete all draft workflow variables")
+ @console_ns.response(204, "Workflow variables deleted successfully")
@_api_prerequisite
def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService(
@@ -237,12 +275,12 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/apps//workflows/draft/nodes//variables")
class NodeVariableCollectionApi(Resource):
- @api.doc("get_node_variables")
- @api.doc(description="Get variables for a specific node")
- @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
- @api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @console_ns.doc("get_node_variables")
+ @console_ns.doc(description="Get variables for a specific node")
+ @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
+ @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@@ -253,9 +291,9 @@ class NodeVariableCollectionApi(Resource):
return node_vars
- @api.doc("delete_node_variables")
- @api.doc(description="Delete all variables for a specific node")
- @api.response(204, "Node variables deleted successfully")
+ @console_ns.doc("delete_node_variables")
+ @console_ns.doc(description="Delete all variables for a specific node")
+ @console_ns.response(204, "Node variables deleted successfully")
@_api_prerequisite
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
@@ -270,13 +308,13 @@ class VariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
- @api.doc("get_variable")
- @api.doc(description="Get a specific workflow variable")
- @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
- @api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
- @api.response(404, "Variable not found")
+ @console_ns.doc("get_variable")
+ @console_ns.doc(description="Get a specific workflow variable")
+ @console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
+ @console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
+ @console_ns.response(404, "Variable not found")
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@@ -288,21 +326,13 @@ class VariableApi(Resource):
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
- @api.doc("update_variable")
- @api.doc(description="Update a workflow variable")
- @api.expect(
- api.model(
- "UpdateVariableRequest",
- {
- "name": fields.String(description="Variable name"),
- "value": fields.Raw(description="Variable value"),
- },
- )
- )
- @api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
- @api.response(404, "Variable not found")
+ @console_ns.doc("update_variable")
+ @console_ns.doc(description="Update a workflow variable")
+ @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
+ @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
+ @console_ns.response(404, "Variable not found")
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
@@ -325,16 +355,10 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
- parser = (
- reqparse.RequestParser()
- .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
- .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
- )
-
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
- args = parser.parse_args(strict=True)
+ args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
@@ -342,8 +366,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
- new_name = args.get(self._PATCH_NAME_FIELD, None)
- raw_value = args.get(self._PATCH_VALUE_FIELD, None)
+ new_name = args_model.name
+ raw_value = args_model.value
if new_name is None and raw_value is None:
return variable
@@ -364,10 +388,10 @@ class VariableApi(Resource):
db.session.commit()
return variable
- @api.doc("delete_variable")
- @api.doc(description="Delete a workflow variable")
- @api.response(204, "Variable deleted successfully")
- @api.response(404, "Variable not found")
+ @console_ns.doc("delete_variable")
+ @console_ns.doc(description="Delete a workflow variable")
+ @console_ns.response(204, "Variable deleted successfully")
+ @console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
@@ -385,12 +409,12 @@ class VariableApi(Resource):
@console_ns.route("/apps//workflows/draft/variables//reset")
class VariableResetApi(Resource):
- @api.doc("reset_variable")
- @api.doc(description="Reset a workflow variable to its default value")
- @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
- @api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
- @api.response(204, "Variable reset (no content)")
- @api.response(404, "Variable not found")
+ @console_ns.doc("reset_variable")
+ @console_ns.doc(description="Reset a workflow variable to its default value")
+ @console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
+ @console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
+ @console_ns.response(204, "Variable reset (no content)")
+ @console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
@@ -414,7 +438,7 @@ class VariableResetApi(Resource):
if resetted is None:
return Response("", 204)
else:
- return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ return marshal(resetted, workflow_draft_variable_model)
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@@ -433,13 +457,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@console_ns.route("/apps//workflows/draft/conversation-variables")
class ConversationVariableCollectionApi(Resource):
- @api.doc("get_conversation_variables")
- @api.doc(description="Get conversation variables for workflow")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
- @api.response(404, "Draft workflow not found")
+ @console_ns.doc("get_conversation_variables")
+ @console_ns.doc(description="Get conversation variables for workflow")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
+ @console_ns.response(404, "Draft workflow not found")
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
@@ -455,23 +479,23 @@ class ConversationVariableCollectionApi(Resource):
@console_ns.route("/apps//workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
- @api.doc("get_system_variables")
- @api.doc(description="Get system variables for workflow")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @console_ns.doc("get_system_variables")
+ @console_ns.doc(description="Get system variables for workflow")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
@console_ns.route("/apps//workflows/draft/environment-variables")
class EnvironmentVariableCollectionApi(Resource):
- @api.doc("get_environment_variables")
- @api.doc(description="Get environment variables for workflow")
- @api.doc(params={"app_id": "Application ID"})
- @api.response(200, "Environment variables retrieved successfully")
- @api.response(404, "Draft workflow not found")
+ @console_ns.doc("get_environment_variables")
+ @console_ns.doc(description="Get environment variables for workflow")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.response(200, "Environment variables retrieved successfully")
+ @console_ns.response(404, "Draft workflow not found")
@_api_prerequisite
def get(self, app_model: App):
"""
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 23c228efbe..8f1871f1e9 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,15 +1,21 @@
-from typing import cast
+from typing import Literal, cast
-from flask_restx import Resource, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
+from fields.end_user_fields import simple_end_user_fields
+from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
+ advanced_chat_workflow_run_for_list_fields,
advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields,
+ workflow_run_for_list_fields,
+ workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
@@ -22,96 +28,148 @@ from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
+# Register models for flask_restx to avoid dict type issues in Swagger
+# Register in dependency order: base models first, then dependent models
-def _parse_workflow_run_list_args():
- """
- Parse common arguments for workflow run list endpoints.
+# Base models
+simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
- Returns:
- Parsed arguments containing last_id, limit, status, and triggered_from filters
- """
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- .add_argument(
- "status",
- type=str,
- choices=WORKFLOW_RUN_STATUS_CHOICES,
- location="args",
- required=False,
- )
- .add_argument(
- "triggered_from",
- type=str,
- choices=["debugging", "app-run"],
- location="args",
- required=False,
- help="Filter by trigger source: debugging or app-run",
- )
+simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
+
+# Models that depend on simple_account_fields
+workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy()
+workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+)
+workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy)
+
+advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy()
+advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+)
+advanced_chat_workflow_run_for_list_model = console_ns.model(
+ "AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy
+)
+
+workflow_run_detail_fields_copy = workflow_run_detail_fields.copy()
+workflow_run_detail_fields_copy["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+)
+workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested(
+ simple_end_user_model, attribute="created_by_end_user", allow_null=True
+)
+workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
+
+workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
+workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+)
+workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
+ simple_end_user_model, attribute="created_by_end_user", allow_null=True
+)
+workflow_run_node_execution_model = console_ns.model(
+ "WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
+)
+
+# Simple models without nested dependencies
+workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
+
+# Pagination models that depend on list models
+advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
+advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
+ fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
+)
+advanced_chat_workflow_run_pagination_model = console_ns.model(
+ "AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
+)
+
+workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
+workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
+workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
+
+workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
+workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
+workflow_run_node_execution_list_model = console_ns.model(
+ "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
+)
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowRunListQuery(BaseModel):
+ last_id: str | None = Field(default=None, description="Last run ID for pagination")
+ limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
+ status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
+ default=None, description="Workflow run status filter"
)
- return parser.parse_args()
-
-
-def _parse_workflow_run_count_args():
- """
- Parse common arguments for workflow run count endpoints.
-
- Returns:
- Parsed arguments containing status, time_range, and triggered_from filters
- """
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "status",
- type=str,
- choices=WORKFLOW_RUN_STATUS_CHOICES,
- location="args",
- required=False,
- )
- .add_argument(
- "time_range",
- type=time_duration,
- location="args",
- required=False,
- help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
- )
- .add_argument(
- "triggered_from",
- type=str,
- choices=["debugging", "app-run"],
- location="args",
- required=False,
- help="Filter by trigger source: debugging or app-run",
- )
+ triggered_from: Literal["debugging", "app-run"] | None = Field(
+ default=None, description="Filter by trigger source: debugging or app-run"
)
- return parser.parse_args()
+
+ @field_validator("last_id")
+ @classmethod
+ def validate_last_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class WorkflowRunCountQuery(BaseModel):
+ status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
+ default=None, description="Workflow run status filter"
+ )
+ time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
+ triggered_from: Literal["debugging", "app-run"] | None = Field(
+ default=None, description="Filter by trigger source: debugging or app-run"
+ )
+
+ @field_validator("time_range")
+ @classmethod
+ def validate_time_range(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return time_duration(value)
+
+
+console_ns.schema_model(
+ WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+ WorkflowRunCountQuery.__name__,
+ WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
@console_ns.route("/apps//advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource):
- @api.doc("get_advanced_chat_workflow_runs")
- @api.doc(description="Get advanced chat workflow run list")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
- @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
- @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
- @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
+ @console_ns.doc("get_advanced_chat_workflow_runs")
+ @console_ns.doc(description="Get advanced chat workflow run list")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
+ @console_ns.doc(
+ params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
+ )
+ @console_ns.doc(
+ params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
+ )
+ @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
+ @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
- @marshal_with(advanced_chat_workflow_run_pagination_fields)
+ @marshal_with(advanced_chat_workflow_run_pagination_model)
def get(self, app_model: App):
"""
Get advanced chat app workflow run list
"""
- args = _parse_workflow_run_list_args()
+ args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
- WorkflowRunTriggeredFrom(args.get("triggered_from"))
- if args.get("triggered_from")
+ WorkflowRunTriggeredFrom(args_model.triggered_from)
+ if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@@ -125,11 +183,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.route("/apps//advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
- @api.doc("get_advanced_chat_workflow_runs_count")
- @api.doc(description="Get advanced chat workflow runs count statistics")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
- @api.doc(
+ @console_ns.doc("get_advanced_chat_workflow_runs_count")
+ @console_ns.doc(description="Get advanced chat workflow runs count statistics")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.doc(
+ params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
+ )
+ @console_ns.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
@@ -137,23 +197,27 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
)
}
)
- @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
- @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
+ @console_ns.doc(
+ params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
+ )
+ @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
+ @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
- @marshal_with(workflow_run_count_fields)
+ @marshal_with(workflow_run_count_model)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
- args = _parse_workflow_run_count_args()
+ args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
- WorkflowRunTriggeredFrom(args.get("triggered_from"))
- if args.get("triggered_from")
+ WorkflowRunTriggeredFrom(args_model.triggered_from)
+ if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@@ -170,28 +234,34 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.route("/apps//workflow-runs")
class WorkflowRunListApi(Resource):
- @api.doc("get_workflow_runs")
- @api.doc(description="Get workflow run list")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
- @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
- @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
- @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
+ @console_ns.doc("get_workflow_runs")
+ @console_ns.doc(description="Get workflow run list")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
+ @console_ns.doc(
+ params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
+ )
+ @console_ns.doc(
+ params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
+ )
+ @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
+ @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_run_pagination_fields)
+ @marshal_with(workflow_run_pagination_model)
def get(self, app_model: App):
"""
Get workflow run list
"""
- args = _parse_workflow_run_list_args()
+ args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
- WorkflowRunTriggeredFrom(args.get("triggered_from"))
- if args.get("triggered_from")
+ WorkflowRunTriggeredFrom(args_model.triggered_from)
+ if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@@ -205,11 +275,13 @@ class WorkflowRunListApi(Resource):
@console_ns.route("/apps//workflow-runs/count")
class WorkflowRunCountApi(Resource):
- @api.doc("get_workflow_runs_count")
- @api.doc(description="Get workflow runs count statistics")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
- @api.doc(
+ @console_ns.doc("get_workflow_runs_count")
+ @console_ns.doc(description="Get workflow runs count statistics")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.doc(
+ params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
+ )
+ @console_ns.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
@@ -217,23 +289,27 @@ class WorkflowRunCountApi(Resource):
)
}
)
- @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
- @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
+ @console_ns.doc(
+ params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
+ )
+ @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
+ @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_run_count_fields)
+ @marshal_with(workflow_run_count_model)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
- args = _parse_workflow_run_count_args()
+ args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
- WorkflowRunTriggeredFrom(args.get("triggered_from"))
- if args.get("triggered_from")
+ WorkflowRunTriggeredFrom(args_model.triggered_from)
+ if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
@@ -250,16 +326,16 @@ class WorkflowRunCountApi(Resource):
@console_ns.route("/apps//workflow-runs/")
class WorkflowRunDetailApi(Resource):
- @api.doc("get_workflow_run_detail")
- @api.doc(description="Get workflow run detail")
- @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
- @api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
- @api.response(404, "Workflow run not found")
+ @console_ns.doc("get_workflow_run_detail")
+ @console_ns.doc(description="Get workflow run detail")
+ @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+ @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
+ @console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_run_detail_fields)
+ @marshal_with(workflow_run_detail_model)
def get(self, app_model: App, run_id):
"""
Get workflow run detail
@@ -274,16 +350,16 @@ class WorkflowRunDetailApi(Resource):
@console_ns.route("/apps//workflow-runs//node-executions")
class WorkflowRunNodeExecutionListApi(Resource):
- @api.doc("get_workflow_run_node_executions")
- @api.doc(description="Get workflow run node execution list")
- @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
- @api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
- @api.response(404, "Workflow run not found")
+ @console_ns.doc("get_workflow_run_node_executions")
+ @console_ns.doc(description="Get workflow run node execution list")
+ @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+ @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
+ @console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
- @marshal_with(workflow_run_node_execution_list_fields)
+ @marshal_with(workflow_run_node_execution_list_model)
def get(self, app_model: App, run_id):
"""
Get workflow run node execution list
diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py
index ef5205c1ee..e48cf42762 100644
--- a/api/controllers/console/app/workflow_statistic.py
+++ b/api/controllers/console/app/workflow_statistic.py
@@ -1,18 +1,38 @@
-from flask import abort, jsonify
-from flask_restx import Resource, reqparse
+from flask import abort, jsonify, request
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
-from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowStatisticQuery(BaseModel):
+ start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
+ end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
+
+ @field_validator("start", "end", mode="before")
+ @classmethod
+ def blank_to_none(cls, value: str | None) -> str | None:
+ if value == "":
+ return None
+ return value
+
+
+console_ns.schema_model(
+ WorkflowStatisticQuery.__name__,
+ WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
@console_ns.route("/apps//workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
@@ -21,11 +41,11 @@ class WorkflowDailyRunsStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
- @api.doc("get_workflow_daily_runs_statistic")
- @api.doc(description="Get workflow daily runs statistics")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
- @api.response(200, "Daily runs statistics retrieved successfully")
+ @console_ns.doc("get_workflow_daily_runs_statistic")
+ @console_ns.doc(description="Get workflow daily runs statistics")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
+ @console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model
@setup_required
@login_required
@@ -33,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- )
- args = parser.parse_args()
+ args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
- start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+ start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -66,11 +81,11 @@ class WorkflowDailyTerminalsStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
- @api.doc("get_workflow_daily_terminals_statistic")
- @api.doc(description="Get workflow daily terminals statistics")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
- @api.response(200, "Daily terminals statistics retrieved successfully")
+ @console_ns.doc("get_workflow_daily_terminals_statistic")
+ @console_ns.doc(description="Get workflow daily terminals statistics")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
+ @console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model
@setup_required
@login_required
@@ -78,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- )
- args = parser.parse_args()
+ args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
- start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+ start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -111,11 +121,11 @@ class WorkflowDailyTokenCostStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
- @api.doc("get_workflow_daily_token_cost_statistic")
- @api.doc(description="Get workflow daily token cost statistics")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
- @api.response(200, "Daily token cost statistics retrieved successfully")
+ @console_ns.doc("get_workflow_daily_token_cost_statistic")
+ @console_ns.doc(description="Get workflow daily token cost statistics")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
+ @console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model
@setup_required
@login_required
@@ -123,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- )
- args = parser.parse_args()
+ args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
- start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+ start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@@ -156,11 +161,11 @@ class WorkflowAverageAppInteractionStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
- @api.doc("get_workflow_average_app_interaction_statistic")
- @api.doc(description="Get workflow average app interaction statistics")
- @api.doc(params={"app_id": "Application ID"})
- @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
- @api.response(200, "Average app interaction statistics retrieved successfully")
+ @console_ns.doc("get_workflow_average_app_interaction_statistic")
+ @console_ns.doc(description="Get workflow average app interaction statistics")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
+ @console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -168,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
- )
- args = parser.parse_args()
+ args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
assert account.timezone is not None
try:
- start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+ start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index 785813c5f0..5d16e4f979 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,14 +1,13 @@
import logging
-from flask_restx import Resource, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
-from controllers.console import api
-from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@@ -16,12 +15,35 @@ from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger
+from .. import console_ns
+from ..app.wraps import get_app_model
+from ..wraps import account_initialization_required, edit_permission_required, setup_required
+
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+class Parser(BaseModel):
+ node_id: str
+
+
+class ParserEnable(BaseModel):
+ trigger_id: str
+ enable_trigger: bool
+
+
+console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+console_ns.schema_model(
+ ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
+ @console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -29,10 +51,9 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
- parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
- args = parser.parse_args()
+ args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
- node_id = str(args["node_id"])
+ node_id = args.node_id
with Session(db.engine) as session:
# Get webhook trigger for this app and node
@@ -51,6 +72,7 @@ class WebhookTriggerApi(Resource):
return webhook_trigger
+@console_ns.route("/apps/<uuid:app_id>/triggers")
class AppTriggersApi(Resource):
"""App Triggers list API"""
@@ -90,7 +112,9 @@ class AppTriggersApi(Resource):
return {"data": triggers}
+@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
+ @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -99,17 +123,11 @@ class AppTriggerEnableApi(Resource):
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
- parser = (
- reqparse.RequestParser()
- .add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
- .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ args = ParserEnable.model_validate(console_ns.payload)
assert current_user.current_tenant_id is not None
- trigger_id = args["trigger_id"]
-
+ trigger_id = args.trigger_id
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
@@ -124,7 +142,7 @@ class AppTriggerEnableApi(Resource):
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
- trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
+ trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
@@ -137,8 +155,3 @@ class AppTriggerEnableApi(Resource):
trigger.icon = "" # type: ignore
return trigger
-
-
-api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
-api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
-api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 2eeef079a1..6834656a7f 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,32 +1,57 @@
from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
-from libs.helper import StrLen, email, extract_remote_ip, timezone
+from libs.helper import EmailStr, extract_remote_ip, timezone
from models import AccountStatus
from services.account_service import AccountService, RegisterService
-active_check_parser = (
- reqparse.RequestParser()
- .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
- .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
- .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
-)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ActivateCheckQuery(BaseModel):
+ workspace_id: str | None = Field(default=None)
+ email: EmailStr | None = Field(default=None)
+ token: str
+
+
+class ActivatePayload(BaseModel):
+ workspace_id: str | None = Field(default=None)
+ email: EmailStr | None = Field(default=None)
+ token: str
+ name: str = Field(..., max_length=30)
+ interface_language: str = Field(...)
+ timezone: str = Field(...)
+
+ @field_validator("interface_language")
+ @classmethod
+ def validate_lang(cls, value: str) -> str:
+ return supported_language(value)
+
+ @field_validator("timezone")
+ @classmethod
+ def validate_tz(cls, value: str) -> str:
+ return timezone(value)
+
+
+for model in (ActivateCheckQuery, ActivatePayload):
+ console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource):
- @api.doc("check_activation_token")
- @api.doc(description="Check if activation token is valid")
- @api.expect(active_check_parser)
- @api.response(
+ @console_ns.doc("check_activation_token")
+ @console_ns.doc(description="Check if activation token is valid")
+ @console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
+ @console_ns.response(
200,
"Success",
- api.model(
+ console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
@@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
),
)
def get(self):
- args = active_check_parser.parse_args()
+ args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- workspaceId = args["workspace_id"]
- reg_email = args["email"]
- token = args["token"]
+ workspaceId = args.workspace_id
+ reg_email = args.email
+ token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
@@ -56,26 +81,15 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
-active_parser = (
- reqparse.RequestParser()
- .add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
- .add_argument("email", type=email, required=False, nullable=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
- .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
- .add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/activate")
class ActivateApi(Resource):
- @api.doc("activate_account")
- @api.doc(description="Activate account with invitation token")
- @api.expect(active_parser)
- @api.response(
+ @console_ns.doc("activate_account")
+ @console_ns.doc(description="Activate account with invitation token")
+ @console_ns.expect(console_ns.models[ActivatePayload.__name__])
+ @console_ns.response(
200,
"Account activated successfully",
- api.model(
+ console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
@@ -83,21 +97,21 @@ class ActivateApi(Resource):
},
),
)
- @api.response(400, "Already activated or invalid token")
+ @console_ns.response(400, "Already activated or invalid token")
def post(self):
- args = active_parser.parse_args()
+ args = ActivatePayload.model_validate(console_ns.payload)
- invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
+ invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
- RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
+ RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"]
- account.name = args["name"]
+ account.name = args.name
- account.interface_language = args["interface_language"]
- account.timezone = args["timezone"]
+ account.interface_language = args.interface_language
+ account.timezone = args.timezone
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py
index 9d7fcef183..905d0daef0 100644
--- a/api/controllers/console/auth/data_source_bearer_auth.py
+++ b/api/controllers/console/auth/data_source_bearer_auth.py
@@ -1,12 +1,26 @@
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
-from controllers.console import console_ns
-from controllers.console.auth.error import ApiKeyAuthFailedError
-from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
-from ..wraps import account_initialization_required, setup_required
+from .. import console_ns
+from ..auth.error import ApiKeyAuthFailedError
+from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ApiKeyAuthBindingPayload(BaseModel):
+ category: str = Field(...)
+ provider: str = Field(...)
+ credentials: dict = Field(...)
+
+
+console_ns.schema_model(
+ ApiKeyAuthBindingPayload.__name__,
+ ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
@console_ns.route("/api-key-auth/data-source")
@@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required
@account_initialization_required
@is_admin_or_owner_required
+ @console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("category", type=str, required=True, nullable=False, location="json")
- .add_argument("provider", type=str, required=True, nullable=False, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
- ApiKeyAuthService.validate_api_key_auth_args(args)
+ payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
+ data = payload.model_dump()
+ ApiKeyAuthService.validate_api_key_auth_args(data)
try:
- ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
+ ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index a27932ccd8..0dd7d33ae9 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -5,12 +5,11 @@ from flask import current_app, redirect, request
from flask_restx import Resource, fields
from configs import dify_config
-from controllers.console import api, console_ns
-from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
-from ..wraps import account_initialization_required, setup_required
+from .. import console_ns
+from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__)
@@ -29,19 +28,19 @@ def get_oauth_providers():
@console_ns.route("/oauth/data-source/<string:provider>")
class OAuthDataSource(Resource):
- @api.doc("oauth_data_source")
- @api.doc(description="Get OAuth authorization URL for data source provider")
- @api.doc(params={"provider": "Data source provider name (notion)"})
- @api.response(
+ @console_ns.doc("oauth_data_source")
+ @console_ns.doc(description="Get OAuth authorization URL for data source provider")
+ @console_ns.doc(params={"provider": "Data source provider name (notion)"})
+ @console_ns.response(
200,
"Authorization URL or internal setup success",
- api.model(
+ console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
)
- @api.response(400, "Invalid provider")
- @api.response(403, "Admin privileges required")
+ @console_ns.response(400, "Invalid provider")
+ @console_ns.response(403, "Admin privileges required")
@is_admin_or_owner_required
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
@@ -63,17 +62,17 @@ class OAuthDataSource(Resource):
@console_ns.route("/oauth/data-source/callback/<string:provider>")
class OAuthDataSourceCallback(Resource):
- @api.doc("oauth_data_source_callback")
- @api.doc(description="Handle OAuth callback from data source provider")
- @api.doc(
+ @console_ns.doc("oauth_data_source_callback")
+ @console_ns.doc(description="Handle OAuth callback from data source provider")
+ @console_ns.doc(
params={
"provider": "Data source provider name (notion)",
"code": "Authorization code from OAuth provider",
"error": "Error message from OAuth provider",
}
)
- @api.response(302, "Redirect to console with result")
- @api.response(400, "Invalid provider")
+ @console_ns.response(302, "Redirect to console with result")
+ @console_ns.response(400, "Invalid provider")
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@@ -94,17 +93,17 @@ class OAuthDataSourceCallback(Resource):
@console_ns.route("/oauth/data-source/binding/<string:provider>")
class OAuthDataSourceBinding(Resource):
- @api.doc("oauth_data_source_binding")
- @api.doc(description="Bind OAuth data source with authorization code")
- @api.doc(
+ @console_ns.doc("oauth_data_source_binding")
+ @console_ns.doc(description="Bind OAuth data source with authorization code")
+ @console_ns.doc(
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
)
- @api.response(
+ @console_ns.response(
200,
"Data source binding success",
- api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
+ console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
)
- @api.response(400, "Invalid provider or code")
+ @console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@@ -128,15 +127,15 @@ class OAuthDataSourceBinding(Resource):
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
class OAuthDataSourceSync(Resource):
- @api.doc("oauth_data_source_sync")
- @api.doc(description="Sync data from OAuth data source")
- @api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
- @api.response(
+ @console_ns.doc("oauth_data_source_sync")
+ @console_ns.doc(description="Sync data from OAuth data source")
+ @console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
+ @console_ns.response(
200,
"Data source sync success",
- api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
+ console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
)
- @api.response(400, "Invalid provider or sync failed")
+ @console_ns.response(400, "Invalid provider or sync failed")
@setup_required
@login_required
@account_initialization_required
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index fe2bb54e0b..fa082c735d 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,5 +1,6 @@
from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -14,16 +15,45 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
-from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
-from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
+from ..error import AccountInFreezeError, EmailSendIpLimitError
+from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class EmailRegisterSendPayload(BaseModel):
+ email: EmailStr = Field(..., description="Email address")
+ language: str | None = Field(default=None, description="Language code")
+
+
+class EmailRegisterValidityPayload(BaseModel):
+ email: EmailStr = Field(...)
+ code: str = Field(...)
+ token: str = Field(...)
+
+
+class EmailRegisterResetPayload(BaseModel):
+ token: str = Field(...)
+ new_password: str = Field(...)
+ password_confirm: str = Field(...)
+
+ @field_validator("new_password", "password_confirm")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
+ console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource):
@@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ args = EmailRegisterSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
- if args["language"] in languages:
- language = args["language"]
+ if args.language in languages:
+ language = args.language
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
- token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
+ token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token}
@@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
- user_email = args["email"]
+ user_email = args.email
- is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
+ is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
- token_data = AccountService.get_email_register_data(args["token"])
+ token_data = AccountService.get_email_register_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
- AccountService.add_email_register_error_rate_limit(args["email"])
+ if args.code != token_data.get("code"):
+ AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_email_register_token(args["token"])
+ AccountService.revoke_email_register_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token(
- user_email, code=args["code"], additional_data={"phase": "register"}
+ user_email, code=args.code, additional_data={"phase": "register"}
)
- AccountService.reset_email_register_error_rate_limit(args["email"])
+ AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
- .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ args = EmailRegisterResetPayload.model_validate(console_ns.payload)
# Validate passwords match
- if args["new_password"] != args["password_confirm"]:
+ if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get register data
- register_data = AccountService.get_email_register_data(args["token"])
+ register_data = AccountService.get_email_register_data(args.token)
if not register_data:
raise InvalidTokenError()
# Must use token in reset phase
@@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
- AccountService.revoke_email_register_token(args["token"])
+ AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
@@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
- account = self._create_new_account(email, args["password_confirm"])
+ account = self._create_new_account(email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 6be6ad51fe..661f591182 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -2,11 +2,12 @@ import base64
import secrets
from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
@@ -18,30 +19,50 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ForgotPasswordSendPayload(BaseModel):
+ email: EmailStr = Field(...)
+ language: str | None = Field(default=None)
+
+
+class ForgotPasswordCheckPayload(BaseModel):
+ email: EmailStr = Field(...)
+ code: str = Field(...)
+ token: str = Field(...)
+
+
+class ForgotPasswordResetPayload(BaseModel):
+ token: str = Field(...)
+ new_password: str = Field(...)
+ password_confirm: str = Field(...)
+
+ @field_validator("new_password", "password_confirm")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
+ console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
- @api.doc("send_forgot_password_email")
- @api.doc(description="Send password reset email")
- @api.expect(
- api.model(
- "ForgotPasswordEmailRequest",
- {
- "email": fields.String(required=True, description="Email address"),
- "language": fields.String(description="Language for email (zh-Hans/en-US)"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("send_forgot_password_email")
+ @console_ns.doc(description="Send password reset email")
+ @console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
+ @console_ns.response(
200,
"Email sent successfully",
- api.model(
+ console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
@@ -50,32 +71,27 @@ class ForgotPasswordSendEmailApi(Resource):
},
),
)
- @api.response(400, "Invalid email or rate limit exceeded")
+ @console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@email_password_login_enabled
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
- email=args["email"],
+ email=args.email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -85,22 +101,13 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
- @api.doc("check_forgot_password_code")
- @api.doc(description="Verify password reset code")
- @api.expect(
- api.model(
- "ForgotPasswordCheckRequest",
- {
- "email": fields.String(required=True, description="Email address"),
- "code": fields.String(required=True, description="Verification code"),
- "token": fields.String(required=True, description="Reset token"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("check_forgot_password_code")
+ @console_ns.doc(description="Verify password reset code")
+ @console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
+ @console_ns.response(
200,
"Code verified successfully",
- api.model(
+ console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
@@ -109,84 +116,63 @@ class ForgotPasswordCheckApi(Resource):
},
),
)
- @api.response(400, "Invalid code or token")
+ @console_ns.response(400, "Invalid code or token")
@setup_required
@email_password_login_enabled
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
- user_email = args["email"]
+ user_email = args.email
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
- token_data = AccountService.get_reset_password_data(args["token"])
+ token_data = AccountService.get_reset_password_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(args["email"])
+ if args.code != token_data.get("code"):
+ AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_reset_password_token(args["token"])
+ AccountService.revoke_reset_password_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=args["code"], additional_data={"phase": "reset"}
+ user_email, code=args.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(args["email"])
+ AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
- @api.doc("reset_password")
- @api.doc(description="Reset password with verification token")
- @api.expect(
- api.model(
- "ForgotPasswordResetRequest",
- {
- "token": fields.String(required=True, description="Verification token"),
- "new_password": fields.String(required=True, description="New password"),
- "password_confirm": fields.String(required=True, description="Password confirmation"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("reset_password")
+ @console_ns.doc(description="Reset password with verification token")
+ @console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
+ @console_ns.response(
200,
"Password reset successfully",
- api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
+ console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
)
- @api.response(400, "Invalid token or password mismatch")
+ @console_ns.response(400, "Invalid token or password mismatch")
@setup_required
@email_password_login_enabled
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
- .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
# Validate passwords match
- if args["new_password"] != args["password_confirm"]:
+ if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
- reset_data = AccountService.get_reset_password_data(args["token"])
+ reset_data = AccountService.get_reset_password_data(args.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
- AccountService.revoke_reset_password_token(args["token"])
+ AccountService.revoke_reset_password_token(args.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
- password_hashed = hash_password(args["new_password"], salt)
+ password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 77ecd5a5e4..f486f4c313 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,6 +1,7 @@
import flask_login
from flask import make_response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
import services
from configs import dify_config
@@ -23,7 +24,7 @@ from controllers.console.error import (
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
@@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class LoginPayload(BaseModel):
+ email: EmailStr = Field(..., description="Email address")
+ password: str = Field(..., description="Password")
+ remember_me: bool = Field(default=False, description="Remember me flag")
+ invite_token: str | None = Field(default=None, description="Invitation token")
+
+
+class EmailPayload(BaseModel):
+ email: EmailStr = Field(...)
+ language: str | None = Field(default=None)
+
+
+class EmailCodeLoginPayload(BaseModel):
+ email: EmailStr = Field(...)
+ code: str = Field(...)
+ token: str = Field(...)
+ language: str | None = Field(default=None)
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(LoginPayload)
+reg(EmailPayload)
+reg(EmailCodeLoginPayload)
+
@console_ns.route("/login")
class LoginApi(Resource):
@@ -47,41 +78,36 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
+ @console_ns.expect(console_ns.models[LoginPayload.__name__])
def post(self):
"""Authenticate user and login."""
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("password", type=str, required=True, location="json")
- .add_argument("remember_me", type=bool, required=False, default=False, location="json")
- .add_argument("invite_token", type=str, required=False, default=None, location="json")
- )
- args = parser.parse_args()
+ args = LoginPayload.model_validate(console_ns.payload)
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
- is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
+ is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
- invitation = args["invite_token"]
+ # TODO: why is `invitation` re-assigned below with a different type?
+ invitation = args.invite_token # type: ignore
if invitation:
- invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
+ invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try:
if invitation:
- data = invitation.get("data", {})
+ data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None
- if invitee_email != args["email"]:
+ if invitee_email != args.email:
raise InvalidEmailError()
- account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
+ account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
- account = AccountService.authenticate(args["email"], args["password"])
+ account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
- AccountService.add_login_error_rate_limit(args["email"])
+ AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@@ -97,7 +123,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args["email"])
+ AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -134,25 +160,21 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
+ @console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ args = EmailPayload.model_validate(console_ns.payload)
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args["email"])
+ account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
- email=args["email"],
+ email=args.email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
+ @console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ args = EmailPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args["email"])
+ account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
- token = AccountService.send_email_code_login_email(email=args["email"], language=language)
+ token = AccountService.send_email_code_login_email(email=args.email, language=language)
else:
raise AccountNotFound()
else:
@@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
+ @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ args = EmailCodeLoginPayload.model_validate(console_ns.payload)
- user_email = args["email"]
- language = args["language"]
+ user_email = args.email
+ language = args.language
- token_data = AccountService.get_email_code_login_data(args["token"])
+ token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args["email"]:
+ if token_data["email"] != args.email:
raise InvalidEmailError()
- if token_data["code"] != args["code"]:
+ if token_data["code"] != args.code:
raise EmailCodeError()
- AccountService.revoke_email_code_login_token(args["token"])
+ AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
@@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args["email"])
+ AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 29653b32ec..7ad1e56373 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -26,7 +26,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
-from .. import api, console_ns
+from .. import console_ns
logger = logging.getLogger(__name__)
@@ -56,11 +56,13 @@ def get_oauth_providers():
@console_ns.route("/oauth/login/")
class OAuthLogin(Resource):
- @api.doc("oauth_login")
- @api.doc(description="Initiate OAuth login process")
- @api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
- @api.response(302, "Redirect to OAuth authorization URL")
- @api.response(400, "Invalid provider")
+ @console_ns.doc("oauth_login")
+ @console_ns.doc(description="Initiate OAuth login process")
+ @console_ns.doc(
+ params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}
+ )
+ @console_ns.response(302, "Redirect to OAuth authorization URL")
+ @console_ns.response(400, "Invalid provider")
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers()
@@ -75,17 +77,17 @@ class OAuthLogin(Resource):
@console_ns.route("/oauth/authorize/")
class OAuthCallback(Resource):
- @api.doc("oauth_callback")
- @api.doc(description="Handle OAuth callback and complete login process")
- @api.doc(
+ @console_ns.doc("oauth_callback")
+ @console_ns.doc(description="Handle OAuth callback and complete login process")
+ @console_ns.doc(
params={
"provider": "OAuth provider name (github/google)",
"code": "Authorization code from OAuth provider",
"state": "Optional state parameter (used for invite token)",
}
)
- @api.response(302, "Redirect to console with access token")
- @api.response(400, "OAuth process failed")
+ @console_ns.response(302, "Redirect to console with access token")
+ @console_ns.response(400, "OAuth process failed")
def get(self, provider: str):
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index 5e12aa7d03..6162d88a0b 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -3,7 +3,8 @@ from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
@@ -20,15 +21,34 @@ R = TypeVar("R")
T = TypeVar("T")
+class OAuthClientPayload(BaseModel):
+ client_id: str
+
+
+class OAuthProviderRequest(BaseModel):
+ client_id: str
+ redirect_uri: str
+
+
+class OAuthTokenRequest(BaseModel):
+ client_id: str
+ grant_type: str
+ code: str | None = None
+ client_secret: str | None = None
+ redirect_uri: str | None = None
+ refresh_token: str | None = None
+
+
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
- parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
- parsed_args = parser.parse_args()
- client_id = parsed_args.get("client_id")
- if not client_id:
+ json_data = request.get_json()
+ if json_data is None:
raise BadRequest("client_id is required")
+ payload = OAuthClientPayload.model_validate(json_data)
+ client_id = payload.client_id
+
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
@@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
- parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
- parsed_args = parser.parse_args()
- redirect_uri = parsed_args.get("redirect_uri")
+ payload = OAuthProviderRequest.model_validate(request.get_json())
+ redirect_uri = payload.redirect_uri
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
@@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
- parser = (
- reqparse.RequestParser()
- .add_argument("grant_type", type=str, required=True, location="json")
- .add_argument("code", type=str, required=False, location="json")
- .add_argument("client_secret", type=str, required=False, location="json")
- .add_argument("redirect_uri", type=str, required=False, location="json")
- .add_argument("refresh_token", type=str, required=False, location="json")
- )
- parsed_args = parser.parse_args()
+ payload = OAuthTokenRequest.model_validate(request.get_json())
try:
- grant_type = OAuthGrantType(parsed_args["grant_type"])
+ grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
- if not parsed_args["code"]:
+ if not payload.code:
raise BadRequest("code is required")
- if parsed_args["client_secret"] != oauth_provider_app.client_secret:
+ if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
- if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
+ if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
- grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
+ grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
@@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
- if not parsed_args["refresh_token"]:
+ if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
- grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
+ grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 6efb4564ca..7f907dc420 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -1,14 +1,45 @@
import base64
-from flask_restx import Resource, fields, reqparse
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class SubscriptionQuery(BaseModel):
+ plan: str = Field(..., description="Subscription plan")
+ interval: str = Field(..., description="Billing interval")
+
+ @field_validator("plan")
+ @classmethod
+ def validate_plan(cls, value: str) -> str:
+ if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
+ raise ValueError("Invalid plan")
+ return value
+
+ @field_validator("interval")
+ @classmethod
+ def validate_interval(cls, value: str) -> str:
+ if value not in {"month", "year"}:
+ raise ValueError("Invalid interval")
+ return value
+
+
+class PartnerTenantsPayload(BaseModel):
+ click_id: str = Field(..., description="Click Id from partner referral link")
+
+
+for model in (SubscriptionQuery, PartnerTenantsPayload):
+ console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
@console_ns.route("/billing/subscription")
class Subscription(Resource):
@@ -18,20 +49,9 @@ class Subscription(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "plan",
- type=str,
- required=True,
- location="args",
- choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
- )
- .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
- )
- args = parser.parse_args()
+ args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
BillingService.is_tenant_owner_or_admin(current_user)
- return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
+ return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@@ -48,28 +68,27 @@ class Invoices(Resource):
@console_ns.route("/billing/partners//tenants")
class PartnerTenants(Resource):
- @api.doc("sync_partner_tenants_bindings")
- @api.doc(description="Sync partner tenants bindings")
- @api.doc(params={"partner_key": "Partner key"})
- @api.expect(
- api.model(
+ @console_ns.doc("sync_partner_tenants_bindings")
+ @console_ns.doc(description="Sync partner tenants bindings")
+ @console_ns.doc(params={"partner_key": "Partner key"})
+ @console_ns.expect(
+ console_ns.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
- @api.response(200, "Tenants synced to partner successfully")
- @api.response(400, "Invalid partner information")
+ @console_ns.response(200, "Tenants synced to partner successfully")
+ @console_ns.response(400, "Invalid partner information")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
- args = parser.parse_args()
try:
- click_id = args["click_id"]
+ args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
+ click_id = args.click_id
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")
diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py
index 2a6889968c..afc5f92b68 100644
--- a/api/controllers/console/billing/compliance.py
+++ b/api/controllers/console/billing/compliance.py
@@ -1,5 +1,6 @@
from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
@@ -9,16 +10,28 @@ from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
+class ComplianceDownloadQuery(BaseModel):
+ doc_name: str = Field(..., description="Compliance document name")
+
+
+console_ns.schema_model(
+ ComplianceDownloadQuery.__name__,
+ ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
+)
+
+
@console_ns.route("/compliance/download")
class ComplianceApi(Resource):
+ @console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
+ @console_ns.doc("download_compliance_document")
+ @console_ns.doc(description="Get compliance document download link")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
- args = parser.parse_args()
+ args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index ef66053075..01f268d94d 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -1,15 +1,15 @@
import json
from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
-from controllers.console import console_ns
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.common.schema import register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
+from .. import console_ns
+from ..wraps import account_initialization_required, setup_required
+
+
+class NotionEstimatePayload(BaseModel):
+ notion_info_list: list[dict[str, Any]]
+ process_rule: dict[str, Any]
+ doc_form: str = Field(default="text_model")
+ doc_language: str = Field(default="English")
+
+
+register_schema_model(console_ns, NotionEstimatePayload)
+
@console_ns.route(
"/data-source/integrates",
@@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
+ @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
def post(self):
_, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
- .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
- .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
- .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump()
# validate args
DocumentService.estimate_args_validate(args)
- notion_info_list = args["notion_info_list"]
+ notion_info_list = payload.notion_info_list
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 3aac571300..70b6e932e9 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,14 +1,19 @@
from typing import Any, cast
from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
-from controllers.console import api, console_ns
-from controllers.console.apikey import api_key_fields, api_key_list
+from controllers.common.schema import register_schema_models
+from controllers.console import console_ns
+from controllers.console.apikey import (
+ api_key_item_model,
+ api_key_list_model,
+)
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import (
@@ -27,21 +32,151 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
-from fields.app_fields import related_app_list
-from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
+from fields.app_fields import app_detail_kernel_fields, related_app_list
+from fields.dataset_fields import (
+ dataset_detail_fields,
+ dataset_fields,
+ dataset_query_detail_fields,
+ dataset_retrieval_model_fields,
+ doc_metadata_fields,
+ external_knowledge_info_fields,
+ external_retrieval_model_fields,
+ icon_info_fields,
+ keyword_setting_fields,
+ reranking_model_fields,
+ tag_fields,
+ vector_setting_fields,
+ weighted_score_fields,
+)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
-def _validate_name(name: str) -> str:
- if not name or len(name) < 1 or len(name) > 40:
- raise ValueError("Name must be between 1 to 40 characters.")
- return name
+def _get_or_create_model(model_name: str, field_def):
+ existing = console_ns.models.get(model_name)
+ if existing is None:
+ existing = console_ns.model(model_name, field_def)
+ return existing
+
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
+
+tag_model = _get_or_create_model("Tag", tag_fields)
+
+keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
+vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
+
+weighted_score_fields_copy = weighted_score_fields.copy()
+weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
+weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
+weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
+
+reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
+
+dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
+dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
+dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
+dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
+
+external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
+
+external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
+
+doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
+
+icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
+
+dataset_detail_fields_copy = dataset_detail_fields.copy()
+dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
+dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
+dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
+dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
+dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
+dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
+dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
+
+dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
+
+app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
+related_app_list_copy = related_app_list.copy()
+related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
+related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
+
+
+def _validate_indexing_technique(value: str | None) -> str | None:
+ if value is None:
+ return value
+ if value not in Dataset.INDEXING_TECHNIQUE_LIST:
+ raise ValueError("Invalid indexing technique.")
+ return value
+
+
+class DatasetCreatePayload(BaseModel):
+ name: str = Field(..., min_length=1, max_length=40)
+ description: str = Field("", max_length=400)
+ indexing_technique: str | None = None
+ permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
+ provider: str = "vendor"
+ external_knowledge_api_id: str | None = None
+ external_knowledge_id: str | None = None
+
+ @field_validator("indexing_technique")
+ @classmethod
+ def validate_indexing(cls, value: str | None) -> str | None:
+ return _validate_indexing_technique(value)
+
+ @field_validator("provider")
+ @classmethod
+ def validate_provider(cls, value: str) -> str:
+ if value not in Dataset.PROVIDER_LIST:
+ raise ValueError("Invalid provider.")
+ return value
+
+
+class DatasetUpdatePayload(BaseModel):
+ name: str | None = Field(None, min_length=1, max_length=40)
+ description: str | None = Field(None, max_length=400)
+ permission: DatasetPermissionEnum | None = None
+ indexing_technique: str | None = None
+ embedding_model: str | None = None
+ embedding_model_provider: str | None = None
+ retrieval_model: dict[str, Any] | None = None
+ partial_member_list: list[str] | None = None
+ external_retrieval_model: dict[str, Any] | None = None
+ external_knowledge_id: str | None = None
+ external_knowledge_api_id: str | None = None
+ icon_info: dict[str, Any] | None = None
+ is_multimodal: bool | None = False
+
+ @field_validator("indexing_technique")
+ @classmethod
+ def validate_indexing(cls, value: str | None) -> str | None:
+ return _validate_indexing_technique(value)
+
+
+class IndexingEstimatePayload(BaseModel):
+ info_list: dict[str, Any]
+ process_rule: dict[str, Any]
+ indexing_technique: str
+ doc_form: str = "text_model"
+ dataset_id: str | None = None
+ doc_language: str = "English"
+
+ @field_validator("indexing_technique")
+ @classmethod
+ def validate_indexing(cls, value: str) -> str:
+ result = _validate_indexing_technique(value)
+ if result is None:
+ raise ValueError("indexing_technique is required.")
+ return result
+
+
+register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -119,9 +254,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
@console_ns.route("/datasets")
class DatasetListApi(Resource):
- @api.doc("get_datasets")
- @api.doc(description="Get list of datasets")
- @api.doc(
+ @console_ns.doc("get_datasets")
+ @console_ns.doc(description="Get list of datasets")
+ @console_ns.doc(
params={
"page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)",
@@ -131,7 +266,7 @@ class DatasetListApi(Resource):
"include_all": "Include all datasets (default: false)",
}
)
- @api.response(200, "Datasets retrieved successfully")
+ @console_ns.response(200, "Datasets retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -184,75 +319,17 @@ class DatasetListApi(Resource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
- @api.doc("create_dataset")
- @api.doc(description="Create a new dataset")
- @api.expect(
- api.model(
- "CreateDatasetRequest",
- {
- "name": fields.String(required=True, description="Dataset name (1-40 characters)"),
- "description": fields.String(description="Dataset description (max 400 characters)"),
- "indexing_technique": fields.String(description="Indexing technique"),
- "permission": fields.String(description="Dataset permission"),
- "provider": fields.String(description="Provider"),
- "external_knowledge_api_id": fields.String(description="External knowledge API ID"),
- "external_knowledge_id": fields.String(description="External knowledge ID"),
- },
- )
- )
- @api.response(201, "Dataset created successfully")
- @api.response(400, "Invalid request parameters")
+ @console_ns.doc("create_dataset")
+ @console_ns.doc(description="Create a new dataset")
+ @console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
+ @console_ns.response(201, "Dataset created successfully")
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="type is required. Name must be between 1 to 40 characters.",
- type=_validate_name,
- )
- .add_argument(
- "description",
- type=validate_description_length,
- nullable=True,
- required=False,
- default="",
- )
- .add_argument(
- "indexing_technique",
- type=str,
- location="json",
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True,
- help="Invalid indexing technique.",
- )
- .add_argument(
- "external_knowledge_api_id",
- type=str,
- nullable=True,
- required=False,
- )
- .add_argument(
- "provider",
- type=str,
- nullable=True,
- choices=Dataset.PROVIDER_LIST,
- required=False,
- default="vendor",
- )
- .add_argument(
- "external_knowledge_id",
- type=str,
- nullable=True,
- required=False,
- )
- )
- args = parser.parse_args()
+ payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -262,14 +339,14 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_tenant_id,
- name=args["name"],
- description=args["description"],
- indexing_technique=args["indexing_technique"],
+ name=payload.name,
+ description=payload.description,
+ indexing_technique=payload.indexing_technique,
account=current_user,
- permission=DatasetPermissionEnum.ONLY_ME,
- provider=args["provider"],
- external_knowledge_api_id=args["external_knowledge_api_id"],
- external_knowledge_id=args["external_knowledge_id"],
+ permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
+ provider=payload.provider,
+ external_knowledge_api_id=payload.external_knowledge_api_id,
+ external_knowledge_id=payload.external_knowledge_id,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -279,12 +356,12 @@ class DatasetListApi(Resource):
@console_ns.route("/datasets/")
class DatasetApi(Resource):
- @api.doc("get_dataset")
- @api.doc(description="Get dataset details")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Dataset retrieved successfully", dataset_detail_fields)
- @api.response(404, "Dataset not found")
- @api.response(403, "Permission denied")
+ @console_ns.doc("get_dataset")
+ @console_ns.doc(description="Get dataset details")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
+ @console_ns.response(404, "Dataset not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -328,23 +405,12 @@ class DatasetApi(Resource):
return data, 200
- @api.doc("update_dataset")
- @api.doc(description="Update dataset details")
- @api.expect(
- api.model(
- "UpdateDatasetRequest",
- {
- "name": fields.String(description="Dataset name"),
- "description": fields.String(description="Dataset description"),
- "permission": fields.String(description="Dataset permission"),
- "indexing_technique": fields.String(description="Indexing technique"),
- "external_retrieval_model": fields.Raw(description="External retrieval model settings"),
- },
- )
- )
- @api.response(200, "Dataset updated successfully", dataset_detail_fields)
- @api.response(404, "Dataset not found")
- @api.response(403, "Permission denied")
+ @console_ns.doc("update_dataset")
+ @console_ns.doc(description="Update dataset details")
+ @console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
+ @console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
+ @console_ns.response(404, "Dataset not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -355,93 +421,25 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- help="type is required. Name must be between 1 to 40 characters.",
- type=_validate_name,
- )
- .add_argument("description", location="json", store_missing=False, type=validate_description_length)
- .add_argument(
- "indexing_technique",
- type=str,
- location="json",
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True,
- help="Invalid indexing technique.",
- )
- .add_argument(
- "permission",
- type=str,
- location="json",
- choices=(
- DatasetPermissionEnum.ONLY_ME,
- DatasetPermissionEnum.ALL_TEAM,
- DatasetPermissionEnum.PARTIAL_TEAM,
- ),
- help="Invalid permission.",
- )
- .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
- .add_argument(
- "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
- )
- .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
- .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
- .add_argument(
- "external_retrieval_model",
- type=dict,
- required=False,
- nullable=True,
- location="json",
- help="Invalid external retrieval model.",
- )
- .add_argument(
- "external_knowledge_id",
- type=str,
- required=False,
- nullable=True,
- location="json",
- help="Invalid external knowledge id.",
- )
- .add_argument(
- "external_knowledge_api_id",
- type=str,
- required=False,
- nullable=True,
- location="json",
- help="Invalid external knowledge api id.",
- )
- .add_argument(
- "icon_info",
- type=dict,
- required=False,
- nullable=True,
- location="json",
- help="Invalid icon info.",
- )
- )
- args = parser.parse_args()
- data = request.get_json()
+ payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
-
# check embedding model setting
if (
- data.get("indexing_technique") == "high_quality"
- and data.get("embedding_model_provider") is not None
- and data.get("embedding_model") is not None
+ payload.indexing_technique == "high_quality"
+ and payload.embedding_model_provider is not None
+ and payload.embedding_model is not None
):
- DatasetService.check_embedding_model_setting(
- dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+ is_multimodal = DatasetService.check_is_multimodal_model(
+ dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
-
+ payload.is_multimodal = is_multimodal
+ payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
- current_user, dataset, data.get("permission"), data.get("partial_member_list")
+ current_user, dataset, payload.permission, payload.partial_member_list
)
- dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
+ dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -449,15 +447,10 @@ class DatasetApi(Resource):
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
- if data.get("partial_member_list") and data.get("permission") == "partial_members":
- DatasetPermissionService.update_partial_member_list(
- tenant_id, dataset_id_str, data.get("partial_member_list")
- )
+ if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
+ DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
# clear partial member list when permission is only_me or all_team_members
- elif (
- data.get("permission") == DatasetPermissionEnum.ONLY_ME
- or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
- ):
+ elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@@ -488,10 +481,10 @@ class DatasetApi(Resource):
@console_ns.route("/datasets//use-check")
class DatasetUseCheckApi(Resource):
- @api.doc("check_dataset_use")
- @api.doc(description="Check if dataset is in use")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Dataset use status retrieved successfully")
+ @console_ns.doc("check_dataset_use")
+ @console_ns.doc(description="Check if dataset is in use")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Dataset use status retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -504,10 +497,10 @@ class DatasetUseCheckApi(Resource):
@console_ns.route("/datasets//queries")
class DatasetQueryApi(Resource):
- @api.doc("get_dataset_queries")
- @api.doc(description="Get dataset query history")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Query history retrieved successfully", dataset_query_detail_fields)
+ @console_ns.doc("get_dataset_queries")
+ @console_ns.doc(description="Get dataset query history")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
@setup_required
@login_required
@account_initialization_required
@@ -529,7 +522,7 @@ class DatasetQueryApi(Resource):
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
- "data": marshal(dataset_queries, dataset_query_detail_fields),
+ "data": marshal(dataset_queries, dataset_query_detail_model),
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
@@ -540,30 +533,16 @@ class DatasetQueryApi(Resource):
@console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource):
- @api.doc("estimate_dataset_indexing")
- @api.doc(description="Estimate dataset indexing cost")
- @api.response(200, "Indexing estimate calculated successfully")
+ @console_ns.doc("estimate_dataset_indexing")
+ @console_ns.doc(description="Estimate dataset indexing cost")
+ @console_ns.response(200, "Indexing estimate calculated successfully")
@setup_required
@login_required
@account_initialization_required
+ @console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
- .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
- .add_argument(
- "indexing_technique",
- type=str,
- required=True,
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True,
- location="json",
- )
- .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
- .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
- .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)
@@ -650,14 +629,14 @@ class DatasetIndexingEstimateApi(Resource):
@console_ns.route("/datasets//related-apps")
class DatasetRelatedAppListApi(Resource):
- @api.doc("get_dataset_related_apps")
- @api.doc(description="Get applications related to dataset")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Related apps retrieved successfully", related_app_list)
+ @console_ns.doc("get_dataset_related_apps")
+ @console_ns.doc(description="Get applications related to dataset")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
@setup_required
@login_required
@account_initialization_required
- @marshal_with(related_app_list)
+ @marshal_with(related_app_list_model)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@@ -683,10 +662,10 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.route("/datasets//indexing-status")
class DatasetIndexingStatusApi(Resource):
- @api.doc("get_dataset_indexing_status")
- @api.doc(description="Get dataset indexing status")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Indexing status retrieved successfully")
+ @console_ns.doc("get_dataset_indexing_status")
+ @console_ns.doc(description="Get dataset indexing status")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Indexing status retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -738,13 +717,13 @@ class DatasetApiKeyApi(Resource):
token_prefix = "dataset-"
resource_type = "dataset"
- @api.doc("get_dataset_api_keys")
- @api.doc(description="Get dataset API keys")
- @api.response(200, "API keys retrieved successfully", api_key_list)
+ @console_ns.doc("get_dataset_api_keys")
+ @console_ns.doc(description="Get dataset API keys")
+ @console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
@setup_required
@login_required
@account_initialization_required
- @marshal_with(api_key_list)
+ @marshal_with(api_key_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
@@ -756,7 +735,7 @@ class DatasetApiKeyApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
- @marshal_with(api_key_fields)
+ @marshal_with(api_key_item_model)
def post(self):
_, current_tenant_id = current_account_with_tenant()
@@ -767,7 +746,7 @@ class DatasetApiKeyApi(Resource):
)
if current_key_count >= self.max_keys:
- api.abort(
+ console_ns.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
@@ -787,10 +766,10 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
resource_type = "dataset"
- @api.doc("delete_dataset_api_key")
- @api.doc(description="Delete dataset API key")
- @api.doc(params={"api_key_id": "API key ID"})
- @api.response(204, "API key deleted successfully")
+ @console_ns.doc("delete_dataset_api_key")
+ @console_ns.doc(description="Delete dataset API key")
+ @console_ns.doc(params={"api_key_id": "API key ID"})
+ @console_ns.response(204, "API key deleted successfully")
@setup_required
@login_required
@is_admin_or_owner_required
@@ -809,7 +788,7 @@ class DatasetApiDeleteApi(Resource):
)
if key is None:
- api.abort(404, message="API key not found")
+ console_ns.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
@@ -832,9 +811,9 @@ class DatasetEnableApiApi(Resource):
@console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource):
- @api.doc("get_dataset_api_base_info")
- @api.doc(description="Get dataset API base information")
- @api.response(200, "API base info retrieved successfully")
+ @console_ns.doc("get_dataset_api_base_info")
+ @console_ns.doc(description="Get dataset API base information")
+ @console_ns.response(200, "API base info retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -844,9 +823,9 @@ class DatasetApiBaseUrlApi(Resource):
@console_ns.route("/datasets/retrieval-setting")
class DatasetRetrievalSettingApi(Resource):
- @api.doc("get_dataset_retrieval_setting")
- @api.doc(description="Get dataset retrieval settings")
- @api.response(200, "Retrieval settings retrieved successfully")
+ @console_ns.doc("get_dataset_retrieval_setting")
+ @console_ns.doc(description="Get dataset retrieval settings")
+ @console_ns.response(200, "Retrieval settings retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -857,10 +836,10 @@ class DatasetRetrievalSettingApi(Resource):
@console_ns.route("/datasets/retrieval-setting/")
class DatasetRetrievalSettingMockApi(Resource):
- @api.doc("get_dataset_retrieval_setting_mock")
- @api.doc(description="Get mock dataset retrieval settings by vector type")
- @api.doc(params={"vector_type": "Vector store type"})
- @api.response(200, "Mock retrieval settings retrieved successfully")
+ @console_ns.doc("get_dataset_retrieval_setting_mock")
+ @console_ns.doc(description="Get mock dataset retrieval settings by vector type")
+ @console_ns.doc(params={"vector_type": "Vector store type"})
+ @console_ns.response(200, "Mock retrieval settings retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -870,11 +849,11 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.route("/datasets//error-docs")
class DatasetErrorDocs(Resource):
- @api.doc("get_dataset_error_docs")
- @api.doc(description="Get dataset error documents")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Error documents retrieved successfully")
- @api.response(404, "Dataset not found")
+ @console_ns.doc("get_dataset_error_docs")
+ @console_ns.doc(description="Get dataset error documents")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Error documents retrieved successfully")
+ @console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@@ -890,12 +869,12 @@ class DatasetErrorDocs(Resource):
@console_ns.route("/datasets//permission-part-users")
class DatasetPermissionUserListApi(Resource):
- @api.doc("get_dataset_permission_users")
- @api.doc(description="Get dataset permission user list")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Permission users retrieved successfully")
- @api.response(404, "Dataset not found")
- @api.response(403, "Permission denied")
+ @console_ns.doc("get_dataset_permission_users")
+ @console_ns.doc(description="Get dataset permission user list")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Permission users retrieved successfully")
+ @console_ns.response(404, "Dataset not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -919,11 +898,11 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.route("/datasets//auto-disable-logs")
class DatasetAutoDisableLogApi(Resource):
- @api.doc("get_dataset_auto_disable_logs")
- @api.doc(description="Get dataset auto disable logs")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.response(200, "Auto disable logs retrieved successfully")
- @api.response(404, "Dataset not found")
+ @console_ns.doc("get_dataset_auto_disable_logs")
+ @console_ns.doc(description="Get dataset auto disable logs")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.response(200, "Auto disable logs retrieved successfully")
+ @console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 92c85b4951..6145da31a5 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -6,31 +6,14 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
-from controllers.console import api, console_ns
-from controllers.console.app.error import (
- ProviderModelCurrentlyNotSupportError,
- ProviderNotInitializeError,
- ProviderQuotaExceededError,
-)
-from controllers.console.datasets.error import (
- ArchivedDocumentImmutableError,
- DocumentAlreadyFinishedError,
- DocumentIndexingError,
- IndexingEstimateError,
- InvalidActionError,
- InvalidMetadataError,
-)
-from controllers.console.wraps import (
- account_initialization_required,
- cloud_edition_billing_rate_limit_check,
- cloud_edition_billing_resource_check,
- setup_required,
-)
+from controllers.common.schema import register_schema_models
+from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
@@ -45,22 +28,92 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db
+from fields.dataset_fields import dataset_fields
from fields.document_fields import (
dataset_and_document_fields,
document_fields,
+ document_metadata_fields,
document_status_fields,
document_with_segments_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
-from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
-from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+
+from ..app.error import (
+ ProviderModelCurrentlyNotSupportError,
+ ProviderNotInitializeError,
+ ProviderQuotaExceededError,
+)
+from ..datasets.error import (
+ ArchivedDocumentImmutableError,
+ DocumentAlreadyFinishedError,
+ DocumentIndexingError,
+ IndexingEstimateError,
+ InvalidActionError,
+ InvalidMetadataError,
+)
+from ..wraps import (
+ account_initialization_required,
+ cloud_edition_billing_rate_limit_check,
+ cloud_edition_billing_resource_check,
+ setup_required,
+)
logger = logging.getLogger(__name__)
+def _get_or_create_model(model_name: str, field_def):
+ existing = console_ns.models.get(model_name)
+ if existing is None:
+ existing = console_ns.model(model_name, field_def)
+ return existing
+
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+dataset_model = _get_or_create_model("Dataset", dataset_fields)
+
+document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
+
+document_fields_copy = document_fields.copy()
+document_fields_copy["doc_metadata"] = fields.List(
+ fields.Nested(document_metadata_model), attribute="doc_metadata_details"
+)
+document_model = _get_or_create_model("Document", document_fields_copy)
+
+document_with_segments_fields_copy = document_with_segments_fields.copy()
+document_with_segments_fields_copy["doc_metadata"] = fields.List(
+ fields.Nested(document_metadata_model), attribute="doc_metadata_details"
+)
+document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
+
+dataset_and_document_fields_copy = dataset_and_document_fields.copy()
+dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
+dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
+dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
+
+
+class DocumentRetryPayload(BaseModel):
+ document_ids: list[str]
+
+
+class DocumentRenamePayload(BaseModel):
+ name: str
+
+
+register_schema_models(
+ console_ns,
+ KnowledgeConfig,
+ ProcessRule,
+ RetrievalModel,
+ DocumentRetryPayload,
+ DocumentRenamePayload,
+)
+
+
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
@@ -104,10 +157,10 @@ class DocumentResource(Resource):
@console_ns.route("/datasets/process-rule")
class GetProcessRuleApi(Resource):
- @api.doc("get_process_rule")
- @api.doc(description="Get dataset document processing rules")
- @api.doc(params={"document_id": "Document ID (optional)"})
- @api.response(200, "Process rules retrieved successfully")
+ @console_ns.doc("get_process_rule")
+ @console_ns.doc(description="Get dataset document processing rules")
+ @console_ns.doc(params={"document_id": "Document ID (optional)"})
+ @console_ns.response(200, "Process rules retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -152,9 +205,9 @@ class GetProcessRuleApi(Resource):
@console_ns.route("/datasets//documents")
class DatasetDocumentListApi(Resource):
- @api.doc("get_dataset_documents")
- @api.doc(description="Get documents in a dataset")
- @api.doc(
+ @console_ns.doc("get_dataset_documents")
+ @console_ns.doc(description="Get documents in a dataset")
+ @console_ns.doc(
params={
"dataset_id": "Dataset ID",
"page": "Page number (default: 1)",
@@ -165,7 +218,7 @@ class DatasetDocumentListApi(Resource):
"status": "Filter documents by display status",
}
)
- @api.response(200, "Documents retrieved successfully")
+ @console_ns.response(200, "Documents retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -276,9 +329,10 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(dataset_and_document_fields)
+ @marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@@ -297,23 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
- )
- .add_argument("data_source", type=dict, required=False, location="json")
- .add_argument("process_rule", type=dict, required=False, location="json")
- .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
- .add_argument("original_document_id", type=str, required=False, location="json")
- .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
- .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
- .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
- .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
- .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
- )
- args = parser.parse_args()
- knowledge_config = KnowledgeConfig.model_validate(args)
+ knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@@ -357,25 +395,15 @@ class DatasetDocumentListApi(Resource):
@console_ns.route("/datasets/init")
class DatasetInitApi(Resource):
- @api.doc("init_dataset")
- @api.doc(description="Initialize dataset with documents")
- @api.expect(
- api.model(
- "DatasetInitRequest",
- {
- "upload_file_id": fields.String(required=True, description="Upload file ID"),
- "indexing_technique": fields.String(description="Indexing technique"),
- "process_rule": fields.Raw(description="Processing rules"),
- "data_source": fields.Raw(description="Data source configuration"),
- },
- )
- )
- @api.response(201, "Dataset initialized successfully", dataset_and_document_fields)
- @api.response(400, "Invalid request parameters")
+ @console_ns.doc("init_dataset")
+ @console_ns.doc(description="Initialize dataset with documents")
+ @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
+ @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
+ @console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @marshal_with(dataset_and_document_fields)
+ @marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@@ -384,27 +412,7 @@ class DatasetInitApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "indexing_technique",
- type=str,
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- required=True,
- nullable=False,
- location="json",
- )
- .add_argument("data_source", type=dict, required=True, nullable=True, location="json")
- .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
- .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
- .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
- .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
- .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
- .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
- )
- args = parser.parse_args()
-
- knowledge_config = KnowledgeConfig.model_validate(args)
+ knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@@ -412,10 +420,14 @@ class DatasetInitApi(Resource):
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
- provider=args["embedding_model_provider"],
+ provider=knowledge_config.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
- model=args["embedding_model"],
+ model=knowledge_config.embedding_model,
)
+ is_multimodal = DatasetService.check_is_multimodal_model(
+ current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+ )
+ knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@@ -446,12 +458,12 @@ class DatasetInitApi(Resource):
@console_ns.route("/datasets//documents//indexing-estimate")
class DocumentIndexingEstimateApi(DocumentResource):
- @api.doc("estimate_document_indexing")
- @api.doc(description="Estimate document indexing cost")
- @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
- @api.response(200, "Indexing estimate calculated successfully")
- @api.response(404, "Document not found")
- @api.response(400, "Document already finished")
+ @console_ns.doc("estimate_document_indexing")
+ @console_ns.doc(description="Estimate document indexing cost")
+ @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @console_ns.response(200, "Indexing estimate calculated successfully")
+ @console_ns.response(404, "Document not found")
+ @console_ns.response(400, "Document already finished")
@setup_required
@login_required
@account_initialization_required
@@ -661,11 +673,11 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@console_ns.route("/datasets//documents//indexing-status")
class DocumentIndexingStatusApi(DocumentResource):
- @api.doc("get_document_indexing_status")
- @api.doc(description="Get document indexing status")
- @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
- @api.response(200, "Indexing status retrieved successfully")
- @api.response(404, "Document not found")
+ @console_ns.doc("get_document_indexing_status")
+ @console_ns.doc(description="Get document indexing status")
+ @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @console_ns.response(200, "Indexing status retrieved successfully")
+ @console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
@@ -711,17 +723,17 @@ class DocumentIndexingStatusApi(DocumentResource):
class DocumentApi(DocumentResource):
METADATA_CHOICES = {"all", "only", "without"}
- @api.doc("get_document")
- @api.doc(description="Get document details")
- @api.doc(
+ @console_ns.doc("get_document")
+ @console_ns.doc(description="Get document details")
+ @console_ns.doc(
params={
"dataset_id": "Dataset ID",
"document_id": "Document ID",
"metadata": "Metadata inclusion (all/only/without)",
}
)
- @api.response(200, "Document retrieved successfully")
- @api.response(404, "Document not found")
+ @console_ns.response(200, "Document retrieved successfully")
+ @console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
@@ -832,14 +844,14 @@ class DocumentApi(DocumentResource):
@console_ns.route("/datasets//documents//processing/")
class DocumentProcessingApi(DocumentResource):
- @api.doc("update_document_processing")
- @api.doc(description="Update document processing status (pause/resume)")
- @api.doc(
+ @console_ns.doc("update_document_processing")
+ @console_ns.doc(description="Update document processing status (pause/resume)")
+ @console_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
)
- @api.response(200, "Processing status updated successfully")
- @api.response(404, "Document not found")
- @api.response(400, "Invalid action")
+ @console_ns.response(200, "Processing status updated successfully")
+ @console_ns.response(404, "Document not found")
+ @console_ns.response(400, "Invalid action")
@setup_required
@login_required
@account_initialization_required
@@ -877,11 +889,11 @@ class DocumentProcessingApi(DocumentResource):
@console_ns.route("/datasets//documents//metadata")
class DocumentMetadataApi(DocumentResource):
- @api.doc("update_document_metadata")
- @api.doc(description="Update document metadata")
- @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
- @api.expect(
- api.model(
+ @console_ns.doc("update_document_metadata")
+ @console_ns.doc(description="Update document metadata")
+ @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @console_ns.expect(
+ console_ns.model(
"UpdateDocumentMetadataRequest",
{
"doc_type": fields.String(description="Document type"),
@@ -889,9 +901,9 @@ class DocumentMetadataApi(DocumentResource):
},
)
)
- @api.response(200, "Document metadata updated successfully")
- @api.response(404, "Document not found")
- @api.response(403, "Permission denied")
+ @console_ns.response(200, "Document metadata updated successfully")
+ @console_ns.response(404, "Document not found")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -1045,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
def post(self, dataset_id):
"""retry document."""
-
- parser = reqparse.RequestParser().add_argument(
- "document_ids", type=list, required=True, nullable=False, location="json"
- )
- args = parser.parse_args()
+ payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound("Dataset not found.")
- for document_id in args["document_ids"]:
+ for document_id in payload.document_ids:
try:
document_id = str(document_id)
@@ -1090,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
@login_required
@account_initialization_required
@marshal_with(document_fields)
+ @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
@@ -1099,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset)
- parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
- args = parser.parse_args()
+ payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
try:
- document = DocumentService.rename_document(dataset_id, document_id, args["name"])
+ document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 2fe7d42e46..e73abc2555 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -1,11 +1,13 @@
import uuid
from flask import request
-from flask_restx import Resource, marshal, reqparse
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
# Query-string parameters accepted by the segment listing endpoints.
class SegmentListQuery(BaseModel):
    # Page size; defaults to 20 and is capped at 100.
    limit: int = Field(default=20, ge=1, le=100)
    # Repeated ?status= values; an empty list means no status filter.
    status: list[str] = Field(default_factory=list)
    # When set, keep only segments whose hit count is >= this value.
    hit_count_gte: int | None = None
    # Tri-state string flag: "all" (no filter), "true" or "false".
    enabled: str = Field(default="all")
    # Substring matched (case-insensitively) against segment content.
    keyword: str | None = None
    page: int = Field(default=1, ge=1)
+
+
# JSON body for creating a single segment.
class SegmentCreatePayload(BaseModel):
    # Segment body text (required).
    content: str
    # Optional answer text; None values are stripped before the service call
    # via model_dump(exclude_none=True).
    answer: str | None = None
    keywords: list[str] | None = None
    attachment_ids: list[str] | None = None
+
+
# JSON body for updating a segment; superset of the create payload.
class SegmentUpdatePayload(BaseModel):
    content: str
    answer: str | None = None
    keywords: list[str] | None = None
    # Forwarded to SegmentUpdateArgs; defaults to False so plain edits do not
    # trigger child-chunk regeneration.
    regenerate_child_chunks: bool = False
    attachment_ids: list[str] | None = None
+
+
# JSON body for the segment batch-import endpoint.
class BatchImportPayload(BaseModel):
    # ID of a previously uploaded file (resolved against UploadFile).
    upload_file_id: str
+
+
# JSON body for creating a child chunk under a segment.
class ChildChunkCreatePayload(BaseModel):
    # Chunk text passed to SegmentService.create_child_chunk.
    content: str
+
+
# JSON body for updating one child chunk.
class ChildChunkUpdatePayload(BaseModel):
    # Replacement chunk text passed to SegmentService.update_child_chunk.
    content: str
+
+
# JSON body for batch-updating a segment's child chunks.
class ChildChunkBatchUpdatePayload(BaseModel):
    # Per-chunk update args, forwarded to SegmentService.update_child_chunks.
    chunks: list[ChildChunkUpdateArgs]
+
+
# Register the Pydantic payload schemas as flask-restx models so the
# @console_ns.expect decorators can reference them by class name in Swagger.
register_schema_models(
    console_ns,
    SegmentListQuery,
    SegmentCreatePayload,
    SegmentUpdatePayload,
    BatchImportPayload,
    ChildChunkCreatePayload,
    ChildChunkUpdatePayload,
    ChildChunkBatchUpdatePayload,
)
+
+
@console_ns.route("/datasets//documents//segments")
class DatasetDocumentSegmentListApi(Resource):
@setup_required
@@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
- parser = (
- reqparse.RequestParser()
- .add_argument("limit", type=int, default=20, location="args")
- .add_argument("status", type=str, action="append", default=[], location="args")
- .add_argument("hit_count_gte", type=int, default=None, location="args")
- .add_argument("enabled", type=str, default="all", location="args")
- .add_argument("keyword", type=str, default=None, location="args")
- .add_argument("page", type=int, default=1, location="args")
+ args = SegmentListQuery.model_validate(
+ {
+ **request.args.to_dict(),
+ "status": request.args.getlist("status"),
+ }
)
- args = parser.parse_args()
-
- page = args["page"]
- limit = min(args["limit"], 100)
- status_list = args["status"]
- hit_count_gte = args["hit_count_gte"]
- keyword = args["keyword"]
+ page = args.page
+ limit = min(args.limit, 100)
+ status_list = args.status
+ hit_count_gte = args.hit_count_gte
+ keyword = args.keyword
query = (
select(DocumentSegment)
@@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
- if args["enabled"].lower() != "all":
- if args["enabled"].lower() == "true":
+ if args.enabled.lower() != "all":
+ if args.enabled.lower() == "true":
query = query.where(DocumentSegment.enabled == True)
- elif args["enabled"].lower() == "false":
+ elif args.enabled.lower() == "false":
query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
- parser = (
- reqparse.RequestParser()
- .add_argument("content", type=str, required=True, nullable=False, location="json")
- .add_argument("answer", type=str, required=False, nullable=True, location="json")
- .add_argument("keywords", type=list, required=False, nullable=True, location="json")
- )
- args = parser.parse_args()
- SegmentService.segment_create_args_validate(args, document)
- segment = SegmentService.create_segment(args, document, dataset)
+ payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
+ payload_dict = payload.model_dump(exclude_none=True)
+ SegmentService.segment_create_args_validate(payload_dict, document)
+ segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
- parser = (
- reqparse.RequestParser()
- .add_argument("content", type=str, required=True, nullable=False, location="json")
- .add_argument("answer", type=str, required=False, nullable=True, location="json")
- .add_argument("keywords", type=list, required=False, nullable=True, location="json")
- .add_argument(
- "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
- )
+ payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
+ payload_dict = payload.model_dump(exclude_none=True)
+ SegmentService.segment_create_args_validate(payload_dict, document)
+ segment = SegmentService.update_segment(
+ SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
- args = parser.parse_args()
- SegmentService.segment_create_args_validate(args, document)
- segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document:
raise NotFound("Document not found.")
- parser = reqparse.RequestParser().add_argument(
- "upload_file_id", type=str, required=True, nullable=False, location="json"
- )
- args = parser.parse_args()
- upload_file_id = args["upload_file_id"]
+ payload = BatchImportPayload.model_validate(console_ns.payload or {})
+ upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
@@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
- parser = reqparse.RequestParser().add_argument(
- "content", type=str, required=True, nullable=False, location="json"
- )
- args = parser.parse_args()
try:
- content = args["content"]
- child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
+ payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
+ child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
- parser = (
- reqparse.RequestParser()
- .add_argument("limit", type=int, default=20, location="args")
- .add_argument("keyword", type=str, default=None, location="args")
- .add_argument("page", type=int, default=1, location="args")
+ args = SegmentListQuery.model_validate(
+ {
+ "limit": request.args.get("limit", default=20, type=int),
+ "keyword": request.args.get("keyword"),
+ "page": request.args.get("page", default=1, type=int),
+ }
)
- args = parser.parse_args()
-
- page = args["page"]
- limit = min(args["limit"], 100)
- keyword = args["keyword"]
+ page = args.page
+ limit = min(args.limit, 100)
+ keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
@@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
- parser = reqparse.RequestParser().add_argument(
- "chunks", type=list, required=True, nullable=False, location="json"
- )
- args = parser.parse_args()
+ payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
try:
- chunks_data = args["chunks"]
- chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
- child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
+ child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
@@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
- parser = reqparse.RequestParser().add_argument(
- "content", type=str, required=True, nullable=False, location="json"
- )
- args = parser.parse_args()
try:
- content = args["content"]
- child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
+ payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
+ child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index fe96a8199a..89c9fcad36 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -1,12 +1,26 @@
from flask import request
-from flask_restx import Resource, fields, marshal, reqparse
+from flask_restx import Resource, fields, marshal
+from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
-from controllers.console import api, console_ns
+from controllers.common.schema import register_schema_models
+from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
-from fields.dataset_fields import dataset_detail_fields
+from fields.dataset_fields import (
+ dataset_detail_fields,
+ dataset_retrieval_model_fields,
+ doc_metadata_fields,
+ external_knowledge_info_fields,
+ external_retrieval_model_fields,
+ icon_info_fields,
+ keyword_setting_fields,
+ reranking_model_fields,
+ tag_fields,
+ vector_setting_fields,
+ weighted_score_fields,
+)
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
@@ -14,24 +28,97 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
-def _validate_name(name: str) -> str:
- if not name or len(name) < 1 or len(name) > 100:
- raise ValueError("Name must be between 1 to 100 characters.")
- return name
def _get_or_create_model(model_name: str, field_def):
    """Look up ``model_name`` among the console namespace's registered
    flask-restx models, registering it from ``field_def`` when missing.

    Keeps model registration idempotent across the controller modules that
    share model names.
    """
    registered = console_ns.models.get(model_name)
    if registered is not None:
        return registered
    # Explicit ``is None`` test on purpose: flask-restx Model subclasses dict,
    # so an empty model would be falsy under an ``or`` shortcut.
    return console_ns.model(model_name, field_def)
+
+
def _build_dataset_detail_model():
    """Register (or reuse) every nested flask-restx model that makes up the
    "DatasetDetail" Swagger model and return the assembled top-level model.

    All sub-models go through _get_or_create_model, so repeated imports of
    this module — or other modules sharing the same model names — reuse one
    registration instead of clobbering each other.
    """
    keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
    vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)

    # The raw field dicts are shared module-level constants: copy before
    # overriding entries with Nested() references to registered models.
    weighted_score_fields_copy = weighted_score_fields.copy()
    weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
    weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
    weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)

    reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)

    dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
    dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
    # weights may legitimately be absent, hence allow_null.
    dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
    dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)

    tag_model = _get_or_create_model("Tag", tag_fields)
    doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
    external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
    external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
    icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)

    # Top-level model: overlay the nested models onto the dataset detail fields.
    dataset_detail_fields_copy = dataset_detail_fields.copy()
    dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
    dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
    dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
    dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
    dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
    dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
    return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
+
+
# Reuse the shared "DatasetDetail" model when another module registered it
# first; otherwise build and register it now.
if "DatasetDetail" in console_ns.models:
    dataset_detail_model = console_ns.models["DatasetDetail"]
else:
    dataset_detail_model = _build_dataset_detail_model()
+
+
# JSON body for creating/updating an external knowledge API template.
class ExternalKnowledgeApiPayload(BaseModel):
    # max_length restored to 100: the reqparse validator this model replaces
    # enforced "between 1 to 100 characters", so a 40-char cap would silently
    # reject previously valid template names.
    name: str = Field(..., min_length=1, max_length=100)
    # Provider-specific settings, validated by
    # ExternalDatasetService.validate_api_list before use.
    settings: dict[str, object]
+
+
# JSON body for creating a dataset backed by an external knowledge base.
class ExternalDatasetCreatePayload(BaseModel):
    # Template to connect through and the remote knowledge base identifier.
    external_knowledge_api_id: str
    external_knowledge_id: str
    # NOTE(review): 40-char cap — confirm this matches the server-side
    # validation the previous parser applied for this endpoint.
    name: str = Field(..., min_length=1, max_length=40)
    description: str | None = Field(None, max_length=400)
    external_retrieval_model: dict[str, object] | None = None
+
+
# JSON body for hit-testing against an external knowledge base.
class ExternalHitTestingPayload(BaseModel):
    # Query text to retrieve against.
    query: str
    external_retrieval_model: dict[str, object] | None = None
    metadata_filtering_conditions: dict[str, object] | None = None
+
+
# JSON body for the Bedrock external-retrieval test endpoint.
class BedrockRetrievalPayload(BaseModel):
    # Retrieval configuration, query text and target knowledge base ID.
    retrieval_setting: dict[str, object]
    query: str
    knowledge_id: str
+
+
# Register the Pydantic payload schemas as flask-restx models so the
# @console_ns.expect decorators can reference them by class name in Swagger.
register_schema_models(
    console_ns,
    ExternalKnowledgeApiPayload,
    ExternalDatasetCreatePayload,
    ExternalHitTestingPayload,
    BedrockRetrievalPayload,
)
@console_ns.route("/datasets/external-knowledge-api")
class ExternalApiTemplateListApi(Resource):
- @api.doc("get_external_api_templates")
- @api.doc(description="Get external knowledge API templates")
- @api.doc(
+ @console_ns.doc("get_external_api_templates")
+ @console_ns.doc(description="Get external knowledge API templates")
+ @console_ns.doc(
params={
"page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)",
"keyword": "Search keyword",
}
)
- @api.response(200, "External API templates retrieved successfully")
+ @console_ns.response(200, "External API templates retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -56,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
+ @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="Name is required. Name must be between 1 to 100 characters.",
- type=_validate_name,
- )
- .add_argument(
- "settings",
- type=dict,
- location="json",
- nullable=False,
- required=True,
- )
- )
- args = parser.parse_args()
+ payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
- ExternalDatasetService.validate_api_list(args["settings"])
+ ExternalDatasetService.validate_api_list(payload.settings)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@@ -85,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
- tenant_id=current_tenant_id, user_id=current_user.id, args=args
+ tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -95,11 +166,11 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
class ExternalApiTemplateApi(Resource):
- @api.doc("get_external_api_template")
- @api.doc(description="Get external knowledge API template details")
- @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
- @api.response(200, "External API template retrieved successfully")
- @api.response(404, "Template not found")
+ @console_ns.doc("get_external_api_template")
+ @console_ns.doc(description="Get external knowledge API template details")
+ @console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
+ @console_ns.response(200, "External API template retrieved successfully")
+ @console_ns.response(404, "Template not found")
@setup_required
@login_required
@account_initialization_required
@@ -114,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
+ @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="type is required. Name must be between 1 to 100 characters.",
- type=_validate_name,
- )
- .add_argument(
- "settings",
- type=dict,
- location="json",
- nullable=False,
- required=True,
- )
- )
- args = parser.parse_args()
- ExternalDatasetService.validate_api_list(args["settings"])
+ payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
+ ExternalDatasetService.validate_api_list(payload.settings)
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
- args=args,
+ args=payload.model_dump(),
)
return external_knowledge_api.to_dict(), 200
@@ -163,10 +218,10 @@ class ExternalApiTemplateApi(Resource):
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
class ExternalApiUseCheckApi(Resource):
- @api.doc("check_external_api_usage")
- @api.doc(description="Check if external knowledge API is being used")
- @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
- @api.response(200, "Usage check completed successfully")
+ @console_ns.doc("check_external_api_usage")
+ @console_ns.doc(description="Check if external knowledge API is being used")
+ @console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
+ @console_ns.response(200, "Usage check completed successfully")
@setup_required
@login_required
@account_initialization_required
@@ -181,22 +236,12 @@ class ExternalApiUseCheckApi(Resource):
@console_ns.route("/datasets/external")
class ExternalDatasetCreateApi(Resource):
- @api.doc("create_external_dataset")
- @api.doc(description="Create external knowledge dataset")
- @api.expect(
- api.model(
- "CreateExternalDatasetRequest",
- {
- "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
- "external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
- "name": fields.String(required=True, description="Dataset name"),
- "description": fields.String(description="Dataset description"),
- },
- )
- )
- @api.response(201, "External dataset created successfully", dataset_detail_fields)
- @api.response(400, "Invalid parameters")
- @api.response(403, "Permission denied")
+ @console_ns.doc("create_external_dataset")
+ @console_ns.doc(description="Create external knowledge dataset")
+ @console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
+ @console_ns.response(201, "External dataset created successfully", dataset_detail_model)
+ @console_ns.response(400, "Invalid parameters")
+ @console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@@ -204,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
- .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="name is required. Name must be between 1 to 100 characters.",
- type=_validate_name,
- )
- .add_argument("description", type=str, required=False, nullable=True, location="json")
- .add_argument("external_retrieval_model", type=dict, required=False, location="json")
- )
-
- args = parser.parse_args()
+ payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@@ -239,22 +270,13 @@ class ExternalDatasetCreateApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing")
class ExternalKnowledgeHitTestingApi(Resource):
- @api.doc("test_external_knowledge_retrieval")
- @api.doc(description="Test external knowledge retrieval for dataset")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.expect(
- api.model(
- "ExternalHitTestingRequest",
- {
- "query": fields.String(required=True, description="Query text for testing"),
- "retrieval_model": fields.Raw(description="Retrieval model configuration"),
- "external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
- },
- )
- )
- @api.response(200, "External hit testing completed successfully")
- @api.response(404, "Dataset not found")
- @api.response(400, "Invalid parameters")
+ @console_ns.doc("test_external_knowledge_retrieval")
+ @console_ns.doc(description="Test external knowledge retrieval for dataset")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
+ @console_ns.response(200, "External hit testing completed successfully")
+ @console_ns.response(404, "Dataset not found")
+ @console_ns.response(400, "Invalid parameters")
@setup_required
@login_required
@account_initialization_required
@@ -270,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
- parser = (
- reqparse.RequestParser()
- .add_argument("query", type=str, location="json")
- .add_argument("external_retrieval_model", type=dict, required=False, location="json")
- .add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
- )
- args = parser.parse_args()
-
- HitTestingService.hit_testing_args_check(args)
+ payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
+ HitTestingService.hit_testing_args_check(payload.model_dump())
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
- query=args["query"],
+ query=payload.query,
account=current_user,
- external_retrieval_model=args["external_retrieval_model"],
- metadata_filtering_conditions=args["metadata_filtering_conditions"],
+ external_retrieval_model=payload.external_retrieval_model,
+ metadata_filtering_conditions=payload.metadata_filtering_conditions,
)
return response
@@ -297,35 +312,15 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.route("/test/retrieval")
class BedrockRetrievalApi(Resource):
# this api is only for internal testing
- @api.doc("bedrock_retrieval_test")
- @api.doc(description="Bedrock retrieval test (internal use only)")
- @api.expect(
- api.model(
- "BedrockRetrievalTestRequest",
- {
- "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
- "query": fields.String(required=True, description="Query text"),
- "knowledge_id": fields.String(required=True, description="Knowledge ID"),
- },
- )
- )
- @api.response(200, "Bedrock retrieval test completed")
+ @console_ns.doc("bedrock_retrieval_test")
+ @console_ns.doc(description="Bedrock retrieval test (internal use only)")
+ @console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
+ @console_ns.response(200, "Bedrock retrieval test completed")
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
- .add_argument(
- "query",
- nullable=False,
- required=True,
- type=str,
- )
- .add_argument("knowledge_id", nullable=False, required=True, type=str)
- )
- args = parser.parse_args()
+ payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
- args["retrieval_setting"], args["query"], args["knowledge_id"]
+ payload.retrieval_setting, payload.query, payload.knowledge_id
)
return result, 200
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index abaca88090..932cb4fcce 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,34 +1,28 @@
-from flask_restx import Resource, fields
+from flask_restx import Resource
-from controllers.console import api, console_ns
-from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
-from controllers.console.wraps import (
+from controllers.common.schema import register_schema_model
+from libs.login import login_required
+
+from .. import console_ns
+from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
+from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
-from libs.login import login_required
+
+register_schema_model(console_ns, HitTestingPayload)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
- @api.doc("test_dataset_retrieval")
- @api.doc(description="Test dataset knowledge retrieval")
- @api.doc(params={"dataset_id": "Dataset ID"})
- @api.expect(
- api.model(
- "HitTestingRequest",
- {
- "query": fields.String(required=True, description="Query text for testing"),
- "retrieval_model": fields.Raw(description="Retrieval model configuration"),
- "top_k": fields.Integer(description="Number of top results to return"),
- "score_threshold": fields.Float(description="Score threshold for filtering results"),
- },
- )
- )
- @api.response(200, "Hit testing completed successfully")
- @api.response(404, "Dataset not found")
- @api.response(400, "Invalid parameters")
+ @console_ns.doc("test_dataset_retrieval")
+ @console_ns.doc(description="Test dataset knowledge retrieval")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
+ @console_ns.response(200, "Hit testing completed successfully")
+ @console_ns.response(404, "Dataset not found")
+ @console_ns.response(400, "Invalid parameters")
@setup_required
@login_required
@account_initialization_required
@@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
- args = self.parse_args()
+ payload = HitTestingPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 99d4d5a29c..db7c50f422 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,6 +1,8 @@
import logging
+from typing import Any
from flask_restx import marshal, reqparse
+from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
+class HitTestingPayload(BaseModel):
+ query: str = Field(max_length=250)
+ retrieval_model: dict[str, Any] | None = None
+ external_retrieval_model: dict[str, Any] | None = None
+ attachment_ids: list[str] | None = None
+
+
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
@@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
return dataset
@staticmethod
- def hit_testing_args_check(args):
+ def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
- .add_argument("query", type=str, location="json")
+ .add_argument("query", type=str, required=False, location="json")
+ .add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
@@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
try:
response = HitTestingService.retrieve(
dataset=dataset,
- query=args["query"],
+ query=args.get("query"),
account=current_user,
- retrieval_model=args["retrieval_model"],
- external_retrieval_model=args["external_retrieval_model"],
+ retrieval_model=args.get("retrieval_model"),
+ external_retrieval_model=args.get("external_retrieval_model"),
+ attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 72b2ff0ff8..8eead1696a 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -1,8 +1,10 @@
from typing import Literal
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
@@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService
+class MetadataUpdatePayload(BaseModel):
+ name: str
+
+
+register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
+register_schema_model(console_ns, MetadataUpdatePayload)
+
+
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource):
@setup_required
@@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
+ @console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("type", type=str, required=True, nullable=False, location="json")
- .add_argument("name", type=str, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
- metadata_args = MetadataArgs.model_validate(args)
+ metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
+ @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
- args = parser.parse_args()
- name = args["name"]
+ payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
+ name = payload.name
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
+ @console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
- parser = reqparse.RequestParser().add_argument(
- "operation_data", type=list, required=True, nullable=False, location="json"
- )
- args = parser.parse_args()
- metadata_args = MetadataOperationData.model_validate(args)
+ metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index f83ee69beb..1a47e226e5 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -1,20 +1,63 @@
+from typing import Any
+
from flask import make_response, redirect, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
-from controllers.console import api, console_ns
+from controllers.common.schema import register_schema_models
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
-from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
+class DatasourceCredentialPayload(BaseModel):
+ name: str | None = Field(default=None, max_length=100)
+ credentials: dict[str, Any]
+
+
+class DatasourceCredentialDeletePayload(BaseModel):
+ credential_id: str
+
+
+class DatasourceCredentialUpdatePayload(BaseModel):
+ credential_id: str
+ name: str | None = Field(default=None, max_length=100)
+ credentials: dict[str, Any] | None = None
+
+
+class DatasourceCustomClientPayload(BaseModel):
+ client_params: dict[str, Any] | None = None
+ enable_oauth_custom_client: bool | None = None
+
+
+class DatasourceDefaultPayload(BaseModel):
+ id: str
+
+
+class DatasourceUpdateNamePayload(BaseModel):
+ credential_id: str
+ name: str = Field(max_length=100)
+
+
+register_schema_models(
+ console_ns,
+ DatasourceCredentialPayload,
+ DatasourceCredentialDeletePayload,
+ DatasourceCredentialUpdatePayload,
+ DatasourceCustomClientPayload,
+ DatasourceDefaultPayload,
+ DatasourceUpdateNamePayload,
+)
+
+
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
-parser_datasource = (
- reqparse.RequestParser()
- .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
- @api.expect(parser_datasource)
+ @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_datasource.parse_args()
+ payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
- credentials=args["credentials"],
- name=args["name"],
+ credentials=payload.credentials,
+ name=payload.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
-parser_datasource_delete = reqparse.RequestParser().add_argument(
- "credential_id", type=str, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
- @api.expect(parser_datasource_delete)
+ @console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
- args = parser_datasource_delete.parse_args()
+ payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
- auth_id=args["credential_id"],
+ auth_id=payload.credential_id,
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
-parser_datasource_update = (
- reqparse.RequestParser()
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
- .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
- .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
- @api.expect(parser_datasource_update)
+ @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
- args = parser_datasource_update.parse_args()
+ payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id,
- auth_id=args["credential_id"],
+ auth_id=payload.credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
- credentials=args.get("credentials", {}),
- name=args.get("name", None),
+ credentials=payload.credentials or {},
+ name=payload.name,
)
return {"result": "success"}, 201
@@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
-parser_datasource_custom = (
- reqparse.RequestParser()
- .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
- .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
- @api.expect(parser_datasource_custom)
+ @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_datasource_custom.parse_args()
+ payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
- client_params=args.get("client_params", {}),
- enabled=args.get("enable_oauth_custom_client", False),
+ client_params=payload.client_params or {},
+ enabled=payload.enable_oauth_custom_client or False,
)
return {"result": "success"}, 200
@@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
-parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
-
-
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
- @api.expect(parser_default)
+ @console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_default.parse_args()
+ payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
- credential_id=args["id"],
+ credential_id=payload.id,
)
return {"result": "success"}, 200
-parser_update_name = (
- reqparse.RequestParser()
- .add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
- .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
- @api.expect(parser_update_name)
+ @console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_update_name.parse_args()
+ payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
- name=args["name"],
- credential_id=args["credential_id"],
+ name=payload.name,
+ credential_id=payload.credential_id,
)
return {"result": "success"}, 200
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
index 5e3b3428eb..42387557d6 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
@@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
- @api.expect(console_ns.models[Parser.__name__], validate=True)
+ @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -38,7 +38,7 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
- args = Parser.model_validate(api.payload)
+ args = Parser.model_validate(console_ns.payload)
inputs = args.inputs
datasource_type = args.datasource_type
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
index f589bba3bf..6e0cd31b8d 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
@@ -1,9 +1,11 @@
import logging
from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
-def _validate_name(name: str) -> str:
- if not name or len(name) < 1 or len(name) > 40:
- raise ValueError("Name must be between 1 to 40 characters.")
- return name
-
-
-def _validate_description_length(description: str) -> str:
- if len(description) > 400:
- raise ValueError("Description cannot exceed 400 characters.")
- return description
-
-
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource):
@setup_required
@@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200
+class Payload(BaseModel):
+ name: str = Field(..., min_length=1, max_length=40)
+ description: str = Field(default="", max_length=400)
+ icon_info: dict[str, object] | None = None
+
+
+register_schema_models(console_ns, Payload)
+
+
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="Name must be between 1 to 40 characters.",
- type=_validate_name,
- )
- .add_argument(
- "description",
- type=_validate_description_length,
- nullable=True,
- required=False,
- default="",
- )
- .add_argument(
- "icon_info",
- type=dict,
- location="json",
- nullable=True,
- )
- )
- args = parser.parse_args()
- pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
+ payload = Payload.model_validate(console_ns.payload or {})
+ pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource):
+ @console_ns.expect(console_ns.models[Payload.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="Name must be between 1 to 40 characters.",
- type=_validate_name,
- )
- .add_argument(
- "description",
- type=_validate_description_length,
- nullable=True,
- required=False,
- default="",
- )
- .add_argument(
- "icon_info",
- type=dict,
- location="json",
- nullable=True,
- )
- )
- args = parser.parse_args()
+ payload = Payload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
- rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
+ rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"}
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
index 98876e9f5e..e65cb19b39 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -1,8 +1,10 @@
-from flask_restx import Resource, marshal, reqparse
+from flask_restx import Resource, marshal
+from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
+from controllers.common.schema import register_schema_model
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
@@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
+class RagPipelineDatasetImportPayload(BaseModel):
+ yaml_content: str
+
+
+register_schema_model(console_ns, RagPipelineDatasetImportPayload)
+
+
@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
+ @console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
- parser = reqparse.RequestParser().add_argument(
- "yaml_content",
- type=str,
- nullable=False,
- required=True,
- help="yaml_content is required.",
- )
-
- args = parser.parse_args()
+ payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
),
permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None,
- yaml_content=args["yaml_content"],
+ yaml_content=payload.yaml_content,
)
try:
with Session(db.engine) as session:
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 858ba94bf8..720e2ce365 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -1,11 +1,13 @@
import logging
-from typing import NoReturn
+from typing import Any, NoReturn
-from flask import Response
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask import Response, request
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
@@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser():
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "page",
- type=inputs.int_range(1, 100_000),
- required=False,
- default=1,
- location="args",
- help="the page of data requested",
- )
- .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
- )
- return parser
+ class PaginationQuery(BaseModel):
+ page: int = Field(default=1, ge=1, le=100_000)
+ limit: int = Field(default=20, ge=1, le=100)
+
+ register_schema_models(console_ns, PaginationQuery)
+
+ return PaginationQuery
+
+
+class WorkflowDraftVariablePatchPayload(BaseModel):
+ name: str | None = None
+ value: Any | None = None
+
+
+register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
@@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
"""
Get draft workflow
"""
- parser = _create_pagination_parser()
- args = parser.parse_args()
+ pagination = _create_pagination_parser()
+ query = pagination.model_validate(request.args.to_dict())
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
@@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id,
- page=args.page,
- limit=args.limit,
+ page=query.page,
+ limit=query.limit,
)
return workflow_vars
@@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
@@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
- parser = (
- reqparse.RequestParser()
- .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
- .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
- )
-
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
- args = parser.parse_args(strict=True)
+ payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
index d658d65b71..d43ee9a6e0 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
@@ -1,6 +1,9 @@
-from flask_restx import Resource, marshal_with, reqparse # type: ignore
+from flask import request
+from flask_restx import Resource, marshal_with # type: ignore
+from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
+class RagPipelineImportPayload(BaseModel):
+ mode: str
+ yaml_content: str | None = None
+ yaml_url: str | None = None
+ name: str | None = None
+ description: str | None = None
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ pipeline_id: str | None = None
+
+
+class IncludeSecretQuery(BaseModel):
+ include_secret: str = Field(default="false")
+
+
+register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
+
+
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
+ @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
-
- parser = (
- reqparse.RequestParser()
- .add_argument("mode", type=str, required=True, location="json")
- .add_argument("yaml_content", type=str, location="json")
- .add_argument("yaml_url", type=str, location="json")
- .add_argument("name", type=str, location="json")
- .add_argument("description", type=str, location="json")
- .add_argument("icon_type", type=str, location="json")
- .add_argument("icon", type=str, location="json")
- .add_argument("icon_background", type=str, location="json")
- .add_argument("pipeline_id", type=str, location="json")
- )
- args = parser.parse_args()
+ payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Create service with session
with Session(db.engine) as session:
@@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
account = current_user
result = import_service.import_rag_pipeline(
account=account,
- import_mode=args["mode"],
- yaml_content=args.get("yaml_content"),
- yaml_url=args.get("yaml_url"),
- pipeline_id=args.get("pipeline_id"),
- dataset_name=args.get("name"),
+ import_mode=payload.mode,
+ yaml_content=payload.yaml_content,
+ yaml_url=payload.yaml_url,
+ pipeline_id=payload.pipeline_id,
+ dataset_name=payload.name,
)
session.commit()
@@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
@edit_permission_required
def get(self, pipeline: Pipeline):
# Add include_secret params
- parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
- args = parser.parse_args()
+ query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
- pipeline=pipeline, include_secret=args["include_secret"] == "true"
+ pipeline=pipeline, include_secret=query.include_secret == "true"
)
return {"data": result}, 200
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index bc8d4fbf81..debe8eed97 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -1,15 +1,17 @@
import json
import logging
-from typing import cast
+from typing import Any, Literal, cast
+from uuid import UUID
from flask import abort, request
-from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
-from flask_restx.inputs import int_range # type: ignore
+from flask_restx import Resource, marshal_with # type: ignore
+from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
-from controllers.console import api, console_ns
+from controllers.common.schema import register_schema_models
+from controllers.console import console_ns
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
@@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields,
)
from libs import helper
-from libs.helper import TimestampField, uuid_value
+from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__)
+class DraftWorkflowSyncPayload(BaseModel):
+ graph: dict[str, Any]
+ hash: str | None = None
+ environment_variables: list[dict[str, Any]] | None = None
+ conversation_variables: list[dict[str, Any]] | None = None
+ rag_pipeline_variables: list[dict[str, Any]] | None = None
+ features: dict[str, Any] | None = None
+
+
+class NodeRunPayload(BaseModel):
+ inputs: dict[str, Any] | None = None
+
+
+class NodeRunRequiredPayload(BaseModel):
+ inputs: dict[str, Any]
+
+
+class DatasourceNodeRunPayload(BaseModel):
+ inputs: dict[str, Any]
+ datasource_type: str
+ credential_id: str | None = None
+
+
+class DraftWorkflowRunPayload(BaseModel):
+ inputs: dict[str, Any]
+ datasource_type: str
+ datasource_info_list: list[dict[str, Any]]
+ start_node_id: str
+
+
+class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
+ is_preview: bool = False
+ response_mode: Literal["streaming", "blocking"] = "streaming"
+ original_document_id: str | None = None
+
+
+class DefaultBlockConfigQuery(BaseModel):
+ q: str | None = None
+
+
+class WorkflowListQuery(BaseModel):
+ page: int = Field(default=1, ge=1, le=99999)
+ limit: int = Field(default=10, ge=1, le=100)
+ user_id: str | None = None
+ named_only: bool = False
+
+
+class WorkflowUpdatePayload(BaseModel):
+ marked_name: str | None = Field(default=None, max_length=20)
+ marked_comment: str | None = Field(default=None, max_length=100)
+
+
+class NodeIdQuery(BaseModel):
+ node_id: str
+
+
+class WorkflowRunQuery(BaseModel):
+ last_id: UUID | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+
+
+class DatasourceVariablesPayload(BaseModel):
+ datasource_type: str
+ datasource_info: dict[str, Any]
+ start_node_id: str
+ start_node_title: str
+
+
+register_schema_models(
+ console_ns,
+ DraftWorkflowSyncPayload,
+ NodeRunPayload,
+ NodeRunRequiredPayload,
+ DatasourceNodeRunPayload,
+ DraftWorkflowRunPayload,
+ PublishedWorkflowRunPayload,
+ DefaultBlockConfigQuery,
+ WorkflowListQuery,
+ WorkflowUpdatePayload,
+ NodeIdQuery,
+ WorkflowRunQuery,
+ DatasourceVariablesPayload,
+)
+
+
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource):
@setup_required
@@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
- parser = (
- reqparse.RequestParser()
- .add_argument("graph", type=dict, required=True, nullable=False, location="json")
- .add_argument("hash", type=str, required=False, location="json")
- .add_argument("environment_variables", type=list, required=False, location="json")
- .add_argument("conversation_variables", type=list, required=False, location="json")
- .add_argument("rag_pipeline_variables", type=list, required=False, location="json")
- )
- args = parser.parse_args()
+ payload_dict = console_ns.payload or {}
elif "text/plain" in content_type:
try:
data = json.loads(request.data.decode("utf-8"))
@@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict")
- args = {
+ payload_dict = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
@@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
else:
abort(415)
+ payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
+
try:
- environment_variables_list = args.get("environment_variables") or []
+ environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
- conversation_variables_list = args.get("conversation_variables") or []
+ conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
- graph=args["graph"],
- unique_hash=args.get("hash"),
+ graph=payload.graph,
+ unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
- rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
+ rag_pipeline_variables=payload.rag_pipeline_variables or [],
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
}
-parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
- @api.expect(parser_run)
+ @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_run.parse_args()
+ payload = NodeRunPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
try:
response = PipelineGenerateService.generate_single_iteration(
@@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
- @api.expect(parser_run)
+ @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_run.parse_args()
+ payload = NodeRunPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
try:
response = PipelineGenerateService.generate_single_loop(
@@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
-parser_draft_run = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("datasource_type", type=str, required=True, location="json")
- .add_argument("datasource_info_list", type=list, required=True, location="json")
- .add_argument("start_node_id", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
- @api.expect(parser_draft_run)
+ @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_draft_run.parse_args()
+ payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump()
try:
response = PipelineGenerateService.generate(
@@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
-parser_published_run = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("datasource_type", type=str, required=True, location="json")
- .add_argument("datasource_info_list", type=list, required=True, location="json")
- .add_argument("start_node_id", type=str, required=True, location="json")
- .add_argument("is_preview", type=bool, required=True, location="json", default=False)
- .add_argument("response_mode", type=str, required=True, location="json", default="streaming")
- .add_argument("original_document_id", type=str, required=False, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
- @api.expect(parser_published_run)
+ @console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_published_run.parse_args()
-
- streaming = args["response_mode"] == "streaming"
+ payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
+ streaming = payload.response_mode == "streaming"
try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
- invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
streaming=streaming,
)
@@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
-parser_rag_run = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("datasource_type", type=str, required=True, location="json")
- .add_argument("credential_id", type=str, required=False, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
- @api.expect(parser_rag_run)
+ @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_rag_run.parse_args()
-
- inputs = args.get("inputs")
- if inputs is None:
- raise ValueError("missing inputs")
- datasource_type = args.get("datasource_type")
- if datasource_type is None:
- raise ValueError("missing datasource_type")
+ payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response(
@@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
- user_inputs=inputs,
+ user_inputs=payload.inputs,
account=current_user,
- datasource_type=datasource_type,
+ datasource_type=payload.datasource_type,
is_published=False,
- credential_id=args.get("credential_id"),
+ credential_id=payload.credential_id,
)
)
)
@@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
- @api.expect(parser_rag_run)
+ @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_rag_run.parse_args()
-
- inputs = args.get("inputs")
- if inputs is None:
- raise ValueError("missing inputs")
- datasource_type = args.get("datasource_type")
- if datasource_type is None:
- raise ValueError("missing datasource_type")
+ payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response(
@@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
- user_inputs=inputs,
+ user_inputs=payload.inputs,
account=current_user,
- datasource_type=datasource_type,
+ datasource_type=payload.datasource_type,
is_published=False,
- credential_id=args.get("credential_id"),
+ credential_id=payload.credential_id,
)
)
)
-parser_run_api = reqparse.RequestParser().add_argument(
- "inputs", type=dict, required=True, nullable=False, location="json"
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
- @api.expect(parser_run_api)
+ @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
- args = parser_run_api.parse_args()
-
- inputs = args.get("inputs")
- if inputs == None:
- raise ValueError("missing inputs")
+ payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
+ inputs = payload.inputs
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
@@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs()
-parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
- @api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
"""
Get default block config
"""
- args = parser_default.parse_args()
-
- q = args.get("q")
+ query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
filters = None
- if q:
+ if query.q:
try:
- filters = json.loads(args.get("q", ""))
+ filters = json.loads(query.q)
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
-parser_wf = (
- reqparse.RequestParser()
- .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
- .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
- .add_argument("user_id", type=str, required=False, location="args")
- .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
- @api.expect(parser_wf)
@setup_required
@login_required
@account_initialization_required
@@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
"""
current_user, _ = current_account_with_tenant()
- args = parser_wf.parse_args()
- page = args["page"]
- limit = args["limit"]
- user_id = args.get("user_id")
- named_only = args.get("named_only", False)
+ query = WorkflowListQuery.model_validate(request.args.to_dict())
+
+ page = query.page
+ limit = query.limit
+ user_id = query.user_id
+ named_only = query.named_only
if user_id:
if user_id != current_user.id:
raise Forbidden()
- user_id = cast(str, user_id)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
@@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
}
-parser_wf_id = (
- reqparse.RequestParser()
- .add_argument("marked_name", type=str, required=False, location="json")
- .add_argument("marked_comment", type=str, required=False, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
- @api.expect(parser_wf_id)
@setup_required
@login_required
@account_initialization_required
@@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
# Check permission
current_user, _ = current_account_with_tenant()
- args = parser_wf_id.parse_args()
-
- # Validate name and comment length
- if args.marked_name and len(args.marked_name) > 20:
- raise ValueError("Marked name cannot exceed 20 characters")
- if args.marked_comment and len(args.marked_comment) > 100:
- raise ValueError("Marked comment cannot exceed 100 characters")
-
- # Prepare update data
- update_data = {}
- if args.get("marked_name") is not None:
- update_data["marked_name"] = args["marked_name"]
- if args.get("marked_comment") is not None:
- update_data["marked_comment"] = args["marked_comment"]
+ payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
+ update_data = payload.model_dump(exclude_unset=True)
if not update_data:
return {"message": "No valid fields to update"}, 400
@@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
return workflow
-parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
- @api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
- args = parser_parameters.parse_args()
- node_id = args.get("node_id")
- if not node_id:
- raise ValueError("Node ID is required")
+ query = NodeIdQuery.model_validate(request.args.to_dict())
+ node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return {
@@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
- @api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
- args = parser_parameters.parse_args()
- node_id = args.get("node_id")
- if not node_id:
- raise ValueError("Node ID is required")
+ query = NodeIdQuery.model_validate(request.args.to_dict())
+ node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return {
@@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
- @api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
- args = parser_parameters.parse_args()
- node_id = args.get("node_id")
- if not node_id:
- raise ValueError("Node ID is required")
+ query = NodeIdQuery.model_validate(request.args.to_dict())
+ node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return {
@@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
- @api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
- args = parser_parameters.parse_args()
- node_id = args.get("node_id")
- if not node_id:
- raise ValueError("Node ID is required")
+ query = NodeIdQuery.model_validate(request.args.to_dict())
+ node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
@@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
}
-parser_wf_run = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-)
-
-
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
- @api.expect(parser_wf_run)
@setup_required
@login_required
@account_initialization_required
@@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
- args = parser_wf_run.parse_args()
+ query = WorkflowRunQuery.model_validate(
+ {
+ "last_id": request.args.get("last_id"),
+ "limit": request.args.get("limit", type=int, default=20),
+ }
+ )
+ args = {
+ "last_id": str(query.last_id) if query.last_id else None,
+ "limit": query.limit,
+ }
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
return result
-parser_var = (
- reqparse.RequestParser()
- .add_argument("datasource_type", type=str, required=True, location="json")
- .add_argument("datasource_info", type=dict, required=True, location="json")
- .add_argument("start_node_id", type=str, required=True, location="json")
- .add_argument("start_node_title", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
- @api.expect(parser_var)
+ @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
- args = parser_var.parse_args()
+ args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index fe6eaaa0de..335c8f6030 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -1,54 +1,46 @@
-from flask_restx import Resource, fields, reqparse
+from typing import Literal
-from controllers.console import api, console_ns
+from flask import request
+from flask_restx import Resource
+from pydantic import BaseModel
+
+from controllers.common.schema import register_schema_models
+from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
+class WebsiteCrawlPayload(BaseModel):
+ provider: Literal["firecrawl", "watercrawl", "jinareader"]
+ url: str
+ options: dict[str, object]
+
+
+class WebsiteCrawlStatusQuery(BaseModel):
+ provider: Literal["firecrawl", "watercrawl", "jinareader"]
+
+
+register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
+
+
@console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource):
- @api.doc("crawl_website")
- @api.doc(description="Crawl website content")
- @api.expect(
- api.model(
- "WebsiteCrawlRequest",
- {
- "provider": fields.String(
- required=True,
- description="Crawl provider (firecrawl/watercrawl/jinareader)",
- enum=["firecrawl", "watercrawl", "jinareader"],
- ),
- "url": fields.String(required=True, description="URL to crawl"),
- "options": fields.Raw(required=True, description="Crawl options"),
- },
- )
- )
- @api.response(200, "Website crawl initiated successfully")
- @api.response(400, "Invalid crawl parameters")
+ @console_ns.doc("crawl_website")
+ @console_ns.doc(description="Crawl website content")
+ @console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
+ @console_ns.response(200, "Website crawl initiated successfully")
+ @console_ns.response(400, "Invalid crawl parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument(
- "provider",
- type=str,
- choices=["firecrawl", "watercrawl", "jinareader"],
- required=True,
- nullable=True,
- location="json",
- )
- .add_argument("url", type=str, required=True, nullable=True, location="json")
- .add_argument("options", type=dict, required=True, nullable=True, location="json")
- )
- args = parser.parse_args()
+ payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
# Create typed request and validate
try:
- api_request = WebsiteCrawlApiRequest.from_args(args)
+ api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
except ValueError as e:
raise WebsiteCrawlError(str(e))
@@ -62,24 +54,22 @@ class WebsiteCrawlApi(Resource):
@console_ns.route("/website/crawl/status/<string:job_id>")
class WebsiteCrawlStatusApi(Resource):
- @api.doc("get_crawl_status")
- @api.doc(description="Get website crawl status")
- @api.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
- @api.response(200, "Crawl status retrieved successfully")
- @api.response(404, "Crawl job not found")
- @api.response(400, "Invalid provider")
+ @console_ns.doc("get_crawl_status")
+ @console_ns.doc(description="Get website crawl status")
+ @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
+ @console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
+ @console_ns.response(200, "Crawl status retrieved successfully")
+ @console_ns.response(404, "Crawl job not found")
+ @console_ns.response(400, "Invalid provider")
@setup_required
@login_required
@account_initialization_required
def get(self, job_id: str):
- parser = reqparse.RequestParser().add_argument(
- "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
- )
- args = parser.parse_args()
+ args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
# Create typed request and validate
try:
- api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
+ api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
except ValueError as e:
raise WebsiteCrawlError(str(e))
diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py
index a8c1298e3e..3ef1341abc 100644
--- a/api/controllers/console/datasets/wraps.py
+++ b/api/controllers/console/datasets/wraps.py
@@ -1,44 +1,40 @@
from collections.abc import Callable
from functools import wraps
+from typing import ParamSpec, TypeVar
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
+P = ParamSpec("P")
+R = TypeVar("R")
-def get_rag_pipeline(
- view: Callable | None = None,
-):
- def decorator(view_func):
- @wraps(view_func)
- def decorated_view(*args, **kwargs):
- if not kwargs.get("pipeline_id"):
- raise ValueError("missing pipeline_id in path parameters")
- _, current_tenant_id = current_account_with_tenant()
+def get_rag_pipeline(view_func: Callable[P, R]):
+ @wraps(view_func)
+ def decorated_view(*args: P.args, **kwargs: P.kwargs):
+ if not kwargs.get("pipeline_id"):
+ raise ValueError("missing pipeline_id in path parameters")
- pipeline_id = kwargs.get("pipeline_id")
- pipeline_id = str(pipeline_id)
+ _, current_tenant_id = current_account_with_tenant()
- del kwargs["pipeline_id"]
+ pipeline_id = kwargs.get("pipeline_id")
+ pipeline_id = str(pipeline_id)
- pipeline = (
- db.session.query(Pipeline)
- .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
- .first()
- )
+ del kwargs["pipeline_id"]
- if not pipeline:
- raise PipelineNotFoundError()
+ pipeline = (
+ db.session.query(Pipeline)
+ .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
+ .first()
+ )
- kwargs["pipeline"] = pipeline
+ if not pipeline:
+ raise PipelineNotFoundError()
- return view_func(*args, **kwargs)
+ kwargs["pipeline"] = pipeline
- return decorated_view
+ return view_func(*args, **kwargs)
- if view is None:
- return decorator
- else:
- return decorator(view)
+ return decorated_view
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 2a248cf20d..0311db1584 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -1,9 +1,11 @@
import logging
from flask import request
+from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
+from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -31,6 +33,16 @@ from .. import console_ns
logger = logging.getLogger(__name__)
+class TextToAudioPayload(BaseModel):
+ message_id: str | None = None
+ voice: str | None = None
+ text: str | None = None
+ streaming: bool | None = Field(default=None, description="Enable streaming response")
+
+
+register_schema_model(console_ns, TextToAudioPayload)
+
+
@console_ns.route(
    "/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio",
@@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
endpoint="installed_app_text",
)
class ChatTextApi(InstalledAppResource):
+ @console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app):
- from flask_restx import reqparse
-
app_model = installed_app.app
try:
- parser = (
- reqparse.RequestParser()
- .add_argument("message_id", type=str, required=False, location="json")
- .add_argument("voice", type=str, location="json")
- .add_argument("text", type=str, location="json")
- .add_argument("streaming", type=bool, location="json")
- )
- args = parser.parse_args()
+ payload = TextToAudioPayload.model_validate(console_ns.payload or {})
- message_id = args.get("message_id", None)
- text = args.get("text", None)
- voice = args.get("voice", None)
+ message_id = payload.message_id
+ text = payload.text
+ voice = payload.voice
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 9386ecebae..5901eca915 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,9 +1,12 @@
import logging
+from typing import Any, Literal
+from uuid import UUID
-from flask_restx import reqparse
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
+from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -15,7 +18,6 @@ from controllers.console.app.error import (
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
-from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@@ -26,11 +28,11 @@ from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from libs.datetime_utils import naive_utc_now
-from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
+from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
from .. import console_ns
@@ -38,28 +40,56 @@ from .. import console_ns
logger = logging.getLogger(__name__)
+class CompletionMessagePayload(BaseModel):
+ inputs: dict[str, Any]
+ query: str = ""
+ files: list[dict[str, Any]] | None = None
+ response_mode: Literal["blocking", "streaming"] | None = None
+ retriever_from: str = Field(default="explore_app")
+
+
+class ChatMessagePayload(BaseModel):
+ inputs: dict[str, Any]
+ query: str
+ files: list[dict[str, Any]] | None = None
+ conversation_id: str | None = None
+ parent_message_id: str | None = None
+ retriever_from: str = Field(default="explore_app")
+
+ @field_validator("conversation_id", "parent_message_id", mode="before")
+ @classmethod
+ def normalize_uuid(cls, value: str | UUID | None) -> str | None:
+ """
+ Accept blank IDs and validate UUID format when provided.
+ """
+ if not value:
+ return None
+
+ try:
+ return helper.uuid_value(value)
+ except ValueError as exc:
+ raise ValueError("must be a valid UUID") from exc
+
+
+register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
+
+
# define completion api for user
@console_ns.route(
    "/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
+ @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
- if app_model.mode != "completion":
+ if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, location="json")
- .add_argument("query", type=str, location="json", default="")
- .add_argument("files", type=list, required=False, location="json")
- .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
- )
- args = parser.parse_args()
+ payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
- streaming = args["response_mode"] == "streaming"
+ streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now()
@@ -102,12 +132,18 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
- if app_model.mode != "completion":
+ if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
- AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
+
+ AppTaskService.stop_task(
+ task_id=task_id,
+ invoke_from=InvokeFrom.EXPLORE,
+ user_id=current_user.id,
+ app_mode=AppMode.value_of(app_model.mode),
+ )
return {"result": "success"}, 200
@@ -117,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
endpoint="installed_app_chat_completion",
)
class ChatApi(InstalledAppResource):
+ @console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, location="json")
- .add_argument("query", type=str, required=True, location="json")
- .add_argument("files", type=list, required=False, location="json")
- .add_argument("conversation_id", type=uuid_value, location="json")
- .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
- .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
- )
- args = parser.parse_args()
+ payload = ChatMessagePayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
args["auto_generate_name"] = False
@@ -184,6 +213,12 @@ class ChatStopApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
- AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
+
+ AppTaskService.stop_task(
+ task_id=task_id,
+ invoke_from=InvokeFrom.EXPLORE,
+ user_id=current_user.id,
+ app_mode=app_mode,
+ )
return {"result": "success"}, 200
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index 5a39363cc2..92da591ab4 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,14 +1,18 @@
-from flask_restx import marshal_with, reqparse
-from flask_restx.inputs import int_range
+from typing import Any
+from uuid import UUID
+
+from flask import request
+from flask_restx import marshal_with
+from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
-from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
@@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
from .. import console_ns
+class ConversationListQuery(BaseModel):
+ last_id: UUID | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+ pinned: bool | None = None
+
+
+class ConversationRenamePayload(BaseModel):
+ name: str | None = None
+ auto_generate: bool = False
+
+ @model_validator(mode="after")
+ def validate_name_requirement(self):
+ if not self.auto_generate:
+ if self.name is None or not self.name.strip():
+ raise ValueError("name is required when auto_generate is false")
+ return self
+
+
+register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
+
+
@console_ns.route(
    "/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
+ @console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- .add_argument("pinned", type=str, choices=["true", "false", None], location="args")
- )
- args = parser.parse_args()
-
- pinned = None
- if "pinned" in args and args["pinned"] is not None:
- pinned = args["pinned"] == "true"
+ raw_args: dict[str, Any] = {
+ "last_id": request.args.get("last_id"),
+ "limit": request.args.get("limit", default=20, type=int),
+ "pinned": request.args.get("pinned"),
+ }
+ if raw_args["last_id"] is None:
+ raw_args["last_id"] = None
+ pinned_value = raw_args["pinned"]
+ if isinstance(pinned_value, str):
+ raw_args["pinned"] = pinned_value == "true"
+ args = ConversationListQuery.model_validate(raw_args)
try:
if not isinstance(current_user, Account):
@@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
session=session,
app_model=app_model,
user=current_user,
- last_id=args["last_id"],
- limit=args["limit"],
+ last_id=str(args.last_id) if args.last_id else None,
+ limit=args.limit,
invoke_from=InvokeFrom.EXPLORE,
- pinned=pinned,
+ pinned=args.pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
+ @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("name", type=str, required=False, location="json")
- .add_argument("auto_generate", type=bool, required=False, default=False, location="json")
- )
- args = parser.parse_args()
+ payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename(
- app_model, conversation_id, current_user, args["name"], args["auto_generate"]
+ app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index db854e09bb..229b7c8865 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,9 +1,13 @@
import logging
+from typing import Literal
+from uuid import UUID
-from flask_restx import marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import marshal_with
+from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound
+from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError,
@@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
-from libs.helper import uuid_value
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -40,12 +43,31 @@ from .. import console_ns
logger = logging.getLogger(__name__)
+class MessageListQuery(BaseModel):
+ conversation_id: UUID
+ first_id: UUID | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+
+
+class MessageFeedbackPayload(BaseModel):
+ rating: Literal["like", "dislike"] | None = None
+ content: str | None = None
+
+
+class MoreLikeThisQuery(BaseModel):
+ response_mode: Literal["blocking", "streaming"]
+
+
+register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
+
+
@console_ns.route(
    "/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
+ @console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
@@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
-
- parser = (
- reqparse.RequestParser()
- .add_argument("conversation_id", required=True, type=uuid_value, location="args")
- .add_argument("first_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- )
- args = parser.parse_args()
+ args = MessageListQuery.model_validate(request.args.to_dict())
try:
return MessageService.pagination_by_first_id(
- app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
+ app_model,
+ current_user,
+ str(args.conversation_id),
+ str(args.first_id) if args.first_id else None,
+ args.limit,
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
endpoint="installed_app_message_feedback",
)
class MessageFeedbackApi(InstalledAppResource):
+ @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
- .add_argument("content", type=str, location="json")
- )
- args = parser.parse_args()
+ payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=current_user,
- rating=args.get("rating"),
- content=args.get("content"),
+ rating=payload.rating,
+ content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
endpoint="installed_app_more_like_this",
)
class MessageMoreLikeThisApi(InstalledAppResource):
+ @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
@@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
message_id = str(message_id)
- parser = reqparse.RequestParser().add_argument(
- "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
- )
- args = parser.parse_args()
+ args = MoreLikeThisQuery.model_validate(request.args.to_dict())
- streaming = args["response_mode"] == "streaming"
+ streaming = args.response_mode == "streaming"
try:
response = AppGenerateService.generate_more_like_this(
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index 11c7a1bc18..2b2f807694 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,7 +1,9 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from constants.languages import languages
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@@ -35,20 +37,26 @@ recommended_app_list_fields = {
}
-parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
+class RecommendedAppsQuery(BaseModel):
+ language: str | None = Field(default=None)
+
+
+console_ns.schema_model(
+ RecommendedAppsQuery.__name__,
+ RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
+)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
- @api.expect(parser_apps)
+ @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
- args = parser_apps.parse_args()
-
- language = args.get("language")
+ args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ language = args.language
if language and language in languages:
language_prefix = language
elif current_user and current_user.interface_language:
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index 9775c951f7..6a9e274a0e 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -1,16 +1,33 @@
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from uuid import UUID
+
+from flask import request
+from flask_restx import fields, marshal_with
+from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField, uuid_value
+from libs.helper import TimestampField
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
+
+class SavedMessageListQuery(BaseModel):
+ last_id: UUID | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+
+
+class SavedMessageCreatePayload(BaseModel):
+ message_id: UUID
+
+
+register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
+
+
feedback_fields = {"rating": fields.String}
message_fields = {
@@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource):
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
+ @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+ args = SavedMessageListQuery.model_validate(request.args.to_dict())
+
+ return SavedMessageService.pagination_by_last_id(
+ app_model,
+ current_user,
+ str(args.last_id) if args.last_id else None,
+ args.limit,
)
- args = parser.parse_args()
-
- return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
+ @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
- parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
- args = parser.parse_args()
+ payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
try:
- SavedMessageService.save(app_model, current_user, args["message_id"])
+ SavedMessageService.save(app_model, current_user, str(payload.message_id))
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py
index 125f603a5a..d679d0722d 100644
--- a/api/controllers/console/explore/workflow.py
+++ b/api/controllers/console/explore/workflow.py
@@ -1,8 +1,10 @@
import logging
+from typing import Any
-from flask_restx import reqparse
+from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
+from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@@ -32,8 +34,17 @@ from .. import console_ns
logger = logging.getLogger(__name__)
+class WorkflowRunPayload(BaseModel):
+ inputs: dict[str, Any]
+ files: list[dict[str, Any]] | None = None
+
+
+register_schema_model(console_ns, WorkflowRunPayload)
+
+
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
+ @console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp):
"""
Run workflow
@@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("files", type=list, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = WorkflowRunPayload.model_validate(console_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py
index a1d36def0d..08f29b4655 100644
--- a/api/controllers/console/extension.py
+++ b/api/controllers/console/extension.py
@@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required
@@ -9,18 +9,24 @@ from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
+api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
+
+api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
+
@console_ns.route("/code-based-extension")
class CodeBasedExtensionAPI(Resource):
- @api.doc("get_code_based_extension")
- @api.doc(description="Get code-based extension data by module name")
- @api.expect(
- api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
+ @console_ns.doc("get_code_based_extension")
+ @console_ns.doc(description="Get code-based extension data by module name")
+ @console_ns.expect(
+ console_ns.parser().add_argument(
+ "module", type=str, required=True, location="args", help="Extension module name"
+ )
)
- @api.response(
+ @console_ns.response(
200,
"Success",
- api.model(
+ console_ns.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
@@ -37,21 +43,21 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
- @api.doc("get_api_based_extensions")
- @api.doc(description="Get all API-based extensions for current tenant")
- @api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
+ @console_ns.doc("get_api_based_extensions")
+ @console_ns.doc(description="Get all API-based extensions for current tenant")
+ @console_ns.response(200, "Success", api_based_extension_list_model)
@setup_required
@login_required
@account_initialization_required
- @marshal_with(api_based_extension_fields)
+ @marshal_with(api_based_extension_model)
def get(self):
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
- @api.doc("create_api_based_extension")
- @api.doc(description="Create a new API-based extension")
- @api.expect(
- api.model(
+ @console_ns.doc("create_api_based_extension")
+ @console_ns.doc(description="Create a new API-based extension")
+ @console_ns.expect(
+ console_ns.model(
"CreateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
@@ -60,13 +66,13 @@ class APIBasedExtensionAPI(Resource):
},
)
)
- @api.response(201, "Extension created successfully", api_based_extension_fields)
+ @console_ns.response(201, "Extension created successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
- @marshal_with(api_based_extension_fields)
+ @marshal_with(api_based_extension_model)
def post(self):
- args = api.payload
+ args = console_ns.payload
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
@@ -81,25 +87,25 @@ class APIBasedExtensionAPI(Resource):
@console_ns.route("/api-based-extension/<uuid:id>")
class APIBasedExtensionDetailAPI(Resource):
- @api.doc("get_api_based_extension")
- @api.doc(description="Get API-based extension by ID")
- @api.doc(params={"id": "Extension ID"})
- @api.response(200, "Success", api_based_extension_fields)
+ @console_ns.doc("get_api_based_extension")
+ @console_ns.doc(description="Get API-based extension by ID")
+ @console_ns.doc(params={"id": "Extension ID"})
+ @console_ns.response(200, "Success", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
- @marshal_with(api_based_extension_fields)
+ @marshal_with(api_based_extension_model)
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
- @api.doc("update_api_based_extension")
- @api.doc(description="Update API-based extension")
- @api.doc(params={"id": "Extension ID"})
- @api.expect(
- api.model(
+ @console_ns.doc("update_api_based_extension")
+ @console_ns.doc(description="Update API-based extension")
+ @console_ns.doc(params={"id": "Extension ID"})
+ @console_ns.expect(
+ console_ns.model(
"UpdateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
@@ -108,18 +114,18 @@ class APIBasedExtensionDetailAPI(Resource):
},
)
)
- @api.response(200, "Extension updated successfully", api_based_extension_fields)
+ @console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
- @marshal_with(api_based_extension_fields)
+ @marshal_with(api_based_extension_model)
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
- args = api.payload
+ args = console_ns.payload
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]
@@ -129,10 +135,10 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.save(extension_data_from_db)
- @api.doc("delete_api_based_extension")
- @api.doc(description="Delete API-based extension")
- @api.doc(params={"id": "Extension ID"})
- @api.response(204, "Extension deleted successfully")
+ @console_ns.doc("delete_api_based_extension")
+ @console_ns.doc(description="Delete API-based extension")
+ @console_ns.doc(params={"id": "Extension ID"})
+ @console_ns.response(204, "Extension deleted successfully")
@setup_required
@login_required
@account_initialization_required
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 39bcf3424c..6951c906e9 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -3,18 +3,18 @@ from flask_restx import Resource, fields
from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService
-from . import api, console_ns
+from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
@console_ns.route("/features")
class FeatureApi(Resource):
- @api.doc("get_tenant_features")
- @api.doc(description="Get feature configuration for current tenant")
- @api.response(
+ @console_ns.doc("get_tenant_features")
+ @console_ns.doc(description="Get feature configuration for current tenant")
+ @console_ns.response(
200,
"Success",
- api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
+ console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
)
@setup_required
@login_required
@@ -29,12 +29,14 @@ class FeatureApi(Resource):
@console_ns.route("/system-features")
class SystemFeatureApi(Resource):
- @api.doc("get_system_features")
- @api.doc(description="Get system-wide feature configuration")
- @api.response(
+ @console_ns.doc("get_system_features")
+ @console_ns.doc(description="Get system-wide feature configuration")
+ @console_ns.response(
200,
"Success",
- api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
+ console_ns.model(
+ "SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
+ ),
)
def get(self):
"""Get system-wide feature configuration"""
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index fdd7c2f479..29417dc896 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+ "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
+ "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
+ "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required
diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py
index f219425d07..2bebe79eac 100644
--- a/api/controllers/console/init_validate.py
+++ b/api/controllers/console/init_validate.py
@@ -1,29 +1,41 @@
import os
from flask import session
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
-from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
-from . import api, console_ns
+from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class InitValidatePayload(BaseModel):
+ password: str = Field(..., max_length=30)
+
+
+console_ns.schema_model(
+ InitValidatePayload.__name__,
+ InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
@console_ns.route("/init")
class InitValidateAPI(Resource):
- @api.doc("get_init_status")
- @api.doc(description="Get initialization validation status")
- @api.response(
+ @console_ns.doc("get_init_status")
+ @console_ns.doc(description="Get initialization validation status")
+ @console_ns.response(
200,
"Success",
- model=api.model(
+ model=console_ns.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
@@ -35,20 +47,15 @@ class InitValidateAPI(Resource):
return {"status": "finished"}
return {"status": "not_started"}
- @api.doc("validate_init_password")
- @api.doc(description="Validate initialization password for self-hosted edition")
- @api.expect(
- api.model(
- "InitValidateRequest",
- {"password": fields.String(required=True, description="Initialization password", max_length=30)},
- )
- )
- @api.response(
+ @console_ns.doc("validate_init_password")
+ @console_ns.doc(description="Validate initialization password for self-hosted edition")
+ @console_ns.expect(console_ns.models[InitValidatePayload.__name__])
+ @console_ns.response(
201,
"Success",
- model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
+ model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
- @api.response(400, "Already setup or validation failed")
+ @console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
@@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
if tenant_count > 0:
raise AlreadySetupError()
- parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
- input_password = parser.parse_args()["password"]
+ payload = InitValidatePayload.model_validate(console_ns.payload)
+ input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index 29f49b99de..25a3d80522 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,16 +1,16 @@
from flask_restx import Resource, fields
-from . import api, console_ns
+from . import console_ns
@console_ns.route("/ping")
class PingApi(Resource):
- @api.doc("health_check")
- @api.doc(description="Health check endpoint for connection testing")
- @api.response(
+ @console_ns.doc("health_check")
+ @console_ns.doc(description="Health check endpoint for connection testing")
+ @console_ns.response(
200,
"Success",
- api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
+ console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
)
def get(self):
"""Health check endpoint for connection testing"""
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index 47c7ecde9a..47eef7eb7e 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -1,7 +1,8 @@
import urllib.parse
import httpx
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
import services
from controllers.common import helpers
@@ -10,7 +11,6 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
-from controllers.console import api
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@@ -37,17 +37,23 @@ class RemoteFileInfoApi(Resource):
}
-parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
+class RemoteFileUploadPayload(BaseModel):
+ url: str = Field(..., description="URL to fetch")
+
+
+console_ns.schema_model(
+ RemoteFileUploadPayload.__name__,
+ RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
+)
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
- @api.expect(parser_upload)
+ @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@marshal_with(file_fields_with_signed_url)
def post(self):
- args = parser_upload.parse_args()
-
- url = args["url"]
+ args = RemoteFileUploadPayload.model_validate(console_ns.payload)
+ url = args.url
try:
resp = ssrf_proxy.head(url=url)
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 22929c851e..7fa02ae280 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,26 +1,47 @@
from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
from configs import dify_config
-from libs.helper import StrLen, email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
-from . import api, console_ns
+from . import console_ns
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class SetupRequestPayload(BaseModel):
+ email: EmailStr = Field(..., description="Admin email address")
+ name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
+ password: str = Field(..., description="Admin password")
+ language: str | None = Field(default=None, description="Admin language")
+
+ @field_validator("password")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+console_ns.schema_model(
+ SetupRequestPayload.__name__,
+ SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
@console_ns.route("/setup")
class SetupApi(Resource):
- @api.doc("get_setup_status")
- @api.doc(description="Get system setup status")
- @api.response(
+ @console_ns.doc("get_setup_status")
+ @console_ns.doc(description="Get system setup status")
+ @console_ns.response(
200,
"Success",
- api.model(
+ console_ns.model(
"SetupStatusResponse",
{
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
@@ -40,21 +61,13 @@ class SetupApi(Resource):
return {"step": "not_started"}
return {"step": "finished"}
- @api.doc("setup_system")
- @api.doc(description="Initialize system setup with admin account")
- @api.expect(
- api.model(
- "SetupRequest",
- {
- "email": fields.String(required=True, description="Admin email address"),
- "name": fields.String(required=True, description="Admin name (max 30 characters)"),
- "password": fields.String(required=True, description="Admin password"),
- "language": fields.String(required=False, description="Admin language"),
- },
- )
+ @console_ns.doc("setup_system")
+ @console_ns.doc(description="Initialize system setup with admin account")
+ @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
+ @console_ns.response(
+ 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
)
- @api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
- @api.response(400, "Already setup or validation failed")
+ @console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Initialize system setup with admin account"""
@@ -70,22 +83,15 @@ class SetupApi(Resource):
if not get_init_validate_status():
raise NotInitValidateError()
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("name", type=StrLen(30), required=True, location="json")
- .add_argument("password", type=valid_password, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ args = SetupRequestPayload.model_validate(console_ns.payload)
# setup
RegisterService.setup(
- email=args["email"],
- name=args["name"],
- password=args["password"],
+ email=args.email,
+ name=args.name,
+ password=args.password,
ip_address=extract_remote_ip(request),
- language=args["language"],
+ language=args.language,
)
return {"result": "success"}, 201
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index ee032756eb..17cfc3ff4b 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
@@ -43,7 +43,7 @@ class TagListApi(Resource):
return tags, 200
- @api.expect(parser_tags)
+ @console_ns.expect(parser_tags)
@setup_required
@login_required
@account_initialization_required
@@ -68,7 +68,7 @@ parser_tag_id = reqparse.RequestParser().add_argument(
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
- @api.expect(parser_tag_id)
+ @console_ns.expect(parser_tag_id)
@setup_required
@login_required
@account_initialization_required
@@ -110,7 +110,7 @@ parser_create = (
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
- @api.expect(parser_create)
+ @console_ns.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@@ -136,7 +136,7 @@ parser_remove = (
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
- @api.expect(parser_remove)
+ @console_ns.expect(parser_remove)
@setup_required
@login_required
@account_initialization_required
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 104a205fc8..419261ba2a 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -2,29 +2,37 @@ import json
import logging
import httpx
-from flask_restx import Resource, fields, reqparse
+from flask import request
+from flask_restx import Resource, fields
from packaging import version
+from pydantic import BaseModel, Field
from configs import dify_config
-from . import api, console_ns
+from . import console_ns
logger = logging.getLogger(__name__)
-parser = reqparse.RequestParser().add_argument(
- "current_version", type=str, required=True, location="args", help="Current application version"
+
+class VersionQuery(BaseModel):
+ current_version: str = Field(..., description="Current application version")
+
+
+console_ns.schema_model(
+ VersionQuery.__name__,
+ VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/version")
class VersionApi(Resource):
- @api.doc("check_version_update")
- @api.doc(description="Check for application version updates")
- @api.expect(parser)
- @api.response(
+ @console_ns.doc("check_version_update")
+ @console_ns.doc(description="Check for application version updates")
+ @console_ns.expect(console_ns.models[VersionQuery.__name__])
+ @console_ns.response(
200,
"Success",
- api.model(
+ console_ns.model(
"VersionResponse",
{
"version": fields.String(description="Latest version number"),
@@ -37,7 +45,7 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
- args = parser.parse_args()
+ args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
@@ -57,16 +65,16 @@ class VersionApi(Resource):
try:
response = httpx.get(
check_update_url,
- params={"current_version": args["current_version"]},
- timeout=httpx.Timeout(connect=3, read=10),
+ params={"current_version": args.current_version},
+ timeout=httpx.Timeout(timeout=10.0, connect=3.0),
)
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
- result["version"] = args["current_version"]
+ result["version"] = args.current_version
return result
content = json.loads(response.content)
- if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
+ if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 0833b39f41..55eaa2f09f 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,14 +1,16 @@
from datetime import datetime
+from typing import Literal
import pytz
from flask import request
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@@ -35,27 +37,142 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
-from libs.helper import TimestampField, email, extract_remote_ip, timezone
+from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-def _init_parser():
- parser = reqparse.RequestParser()
- if dify_config.EDITION == "CLOUD":
- parser.add_argument("invitation_code", type=str, location="json")
- parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
- "timezone", type=timezone, required=True, location="json"
- )
- return parser
+
+class AccountInitPayload(BaseModel):
+ interface_language: str
+ timezone: str
+ invitation_code: str | None = None
+
+ @field_validator("interface_language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+ @field_validator("timezone")
+ @classmethod
+ def validate_timezone(cls, value: str) -> str:
+ return timezone(value)
+
+
+class AccountNamePayload(BaseModel):
+ name: str = Field(min_length=3, max_length=30)
+
+
+class AccountAvatarPayload(BaseModel):
+ avatar: str
+
+
+class AccountInterfaceLanguagePayload(BaseModel):
+ interface_language: str
+
+ @field_validator("interface_language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+
+class AccountInterfaceThemePayload(BaseModel):
+ interface_theme: Literal["light", "dark"]
+
+
+class AccountTimezonePayload(BaseModel):
+ timezone: str
+
+ @field_validator("timezone")
+ @classmethod
+ def validate_timezone(cls, value: str) -> str:
+ return timezone(value)
+
+
+class AccountPasswordPayload(BaseModel):
+ password: str | None = None
+ new_password: str
+ repeat_new_password: str
+
+ @model_validator(mode="after")
+ def check_passwords_match(self) -> "AccountPasswordPayload":
+ if self.new_password != self.repeat_new_password:
+ raise RepeatPasswordNotMatchError()
+ return self
+
+
+class AccountDeletePayload(BaseModel):
+ token: str
+ code: str
+
+
+class AccountDeletionFeedbackPayload(BaseModel):
+ email: EmailStr
+ feedback: str
+
+
+class EducationActivatePayload(BaseModel):
+ token: str
+ institution: str
+ role: str
+
+
+class EducationAutocompleteQuery(BaseModel):
+ keywords: str
+ page: int = 0
+ limit: int = 20
+
+
+class ChangeEmailSendPayload(BaseModel):
+ email: EmailStr
+ language: str | None = None
+ phase: str | None = None
+ token: str | None = None
+
+
+class ChangeEmailValidityPayload(BaseModel):
+ email: EmailStr
+ code: str
+ token: str
+
+
+class ChangeEmailResetPayload(BaseModel):
+ new_email: EmailStr
+ token: str
+
+
+class CheckEmailUniquePayload(BaseModel):
+ email: EmailStr
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(AccountInitPayload)
+reg(AccountNamePayload)
+reg(AccountAvatarPayload)
+reg(AccountInterfaceLanguagePayload)
+reg(AccountInterfaceThemePayload)
+reg(AccountTimezonePayload)
+reg(AccountPasswordPayload)
+reg(AccountDeletePayload)
+reg(AccountDeletionFeedbackPayload)
+reg(EducationActivatePayload)
+reg(EducationAutocompleteQuery)
+reg(ChangeEmailSendPayload)
+reg(ChangeEmailValidityPayload)
+reg(ChangeEmailResetPayload)
+reg(CheckEmailUniquePayload)
@console_ns.route("/account/init")
class AccountInitApi(Resource):
- @api.expect(_init_parser())
+ @console_ns.expect(console_ns.models[AccountInitPayload.__name__])
@setup_required
@login_required
def post(self):
@@ -64,17 +181,18 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
- args = _init_parser().parse_args()
+ payload = console_ns.payload or {}
+ args = AccountInitPayload.model_validate(payload)
if dify_config.EDITION == "CLOUD":
- if not args["invitation_code"]:
+ if not args.invitation_code:
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.where(
- InvitationCode.code == args["invitation_code"],
+ InvitationCode.code == args.invitation_code,
InvitationCode.status == "unused",
)
.first()
@@ -88,8 +206,8 @@ class AccountInitApi(Resource):
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
- account.interface_language = args["interface_language"]
- account.timezone = args["timezone"]
+ account.interface_language = args.interface_language
+ account.timezone = args.timezone
account.interface_theme = "light"
account.status = "active"
account.initialized_at = naive_utc_now()
@@ -110,137 +228,104 @@ class AccountProfileApi(Resource):
return current_user
-parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
-
-
@console_ns.route("/account/name")
class AccountNameApi(Resource):
- @api.expect(parser_name)
+ @console_ns.expect(console_ns.models[AccountNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_name.parse_args()
-
- # Validate account name length
- if len(args["name"]) < 3 or len(args["name"]) > 30:
- raise ValueError("Account name must be between 3 and 30 characters.")
-
- updated_account = AccountService.update_account(current_user, name=args["name"])
+ payload = console_ns.payload or {}
+ args = AccountNamePayload.model_validate(payload)
+ updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
-parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
-
-
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
- @api.expect(parser_avatar)
+ @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_avatar.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountAvatarPayload.model_validate(payload)
- updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
+ updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
-parser_interface = reqparse.RequestParser().add_argument(
- "interface_language", type=supported_language, required=True, location="json"
-)
-
-
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
- @api.expect(parser_interface)
+ @console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_interface.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountInterfaceLanguagePayload.model_validate(payload)
- updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
+ updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
-parser_theme = reqparse.RequestParser().add_argument(
- "interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
-)
-
-
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
- @api.expect(parser_theme)
+ @console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_theme.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountInterfaceThemePayload.model_validate(payload)
- updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
+ updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
-parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
-
-
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
- @api.expect(parser_timezone)
+ @console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_timezone.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountTimezonePayload.model_validate(payload)
- # Validate timezone string, e.g. America/New_York, Asia/Shanghai
- if args["timezone"] not in pytz.all_timezones:
- raise ValueError("Invalid timezone string.")
-
- updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
+ updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
-parser_pw = (
- reqparse.RequestParser()
- .add_argument("password", type=str, required=False, location="json")
- .add_argument("new_password", type=str, required=True, location="json")
- .add_argument("repeat_new_password", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
- @api.expect(parser_pw)
+ @console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_pw.parse_args()
-
- if args["new_password"] != args["repeat_new_password"]:
- raise RepeatPasswordNotMatchError()
+ payload = console_ns.payload or {}
+ args = AccountPasswordPayload.model_validate(payload)
try:
- AccountService.update_account_password(current_user, args["password"], args["new_password"])
+ AccountService.update_account_password(current_user, args.password, args.new_password)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@@ -316,25 +401,19 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
-parser_delete = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
- @api.expect(parser_delete)
+ @console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
- args = parser_delete.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountDeletePayload.model_validate(payload)
- if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
+ if not AccountService.verify_account_deletion_code(args.token, args.code):
raise InvalidAccountDeletionCodeError()
AccountService.delete_account(account)
@@ -342,21 +421,15 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
-parser_feedback = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("feedback", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
- @api.expect(parser_feedback)
+ @console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
@setup_required
def post(self):
- args = parser_feedback.parse_args()
+ payload = console_ns.payload or {}
+ args = AccountDeletionFeedbackPayload.model_validate(payload)
- BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
+ BillingService.update_account_deletion_feedback(args.email, args.feedback)
return {"result": "success"}
@@ -379,14 +452,6 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
-parser_edu = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, location="json")
- .add_argument("institution", type=str, required=True, location="json")
- .add_argument("role", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@@ -396,7 +461,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
- @api.expect(parser_edu)
+ @console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -405,9 +470,10 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
- args = parser_edu.parse_args()
+ payload = console_ns.payload or {}
+ args = EducationActivatePayload.model_validate(payload)
- return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
+ return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)
@setup_required
@login_required
@@ -425,14 +491,6 @@ class EducationApi(Resource):
return res
-parser_autocomplete = (
- reqparse.RequestParser()
- .add_argument("keywords", type=str, required=True, location="args")
- .add_argument("page", type=int, required=False, location="args", default=0)
- .add_argument("limit", type=int, required=False, location="args", default=20)
-)
-
-
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@@ -441,7 +499,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
- @api.expect(parser_autocomplete)
+ @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -449,46 +507,39 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
- args = parser_autocomplete.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = EducationAutocompleteQuery.model_validate(payload)
- return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
-
-
-parser_change_email = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- .add_argument("phase", type=str, required=False, location="json")
- .add_argument("token", type=str, required=False, location="json")
-)
+ return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
- @api.expect(parser_change_email)
+ @console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
- args = parser_change_email.parse_args()
+ payload = console_ns.payload or {}
+ args = ChangeEmailSendPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = None
- user_email = args["email"]
- if args["phase"] is not None and args["phase"] == "new_email":
- if args["token"] is None:
+ user_email = args.email
+ if args.phase is not None and args.phase == "new_email":
+ if args.token is None:
raise InvalidTokenError()
- reset_data = AccountService.get_change_email_data(args["token"])
+ reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
@@ -497,118 +548,103 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidEmailError()
else:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
token = AccountService.send_change_email_email(
- account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
+ account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
-parser_validity = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
- @api.expect(parser_validity)
+ @console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
- args = parser_validity.parse_args()
+ payload = console_ns.payload or {}
+ args = ChangeEmailValidityPayload.model_validate(payload)
- user_email = args["email"]
+ user_email = args.email
- is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
- token_data = AccountService.get_change_email_data(args["token"])
+ token_data = AccountService.get_change_email_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
- AccountService.add_change_email_error_rate_limit(args["email"])
+ if args.code != token_data.get("code"):
+ AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_change_email_token(args["token"])
+ AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
- user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
+ user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
- AccountService.reset_change_email_error_rate_limit(args["email"])
+ AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
-parser_reset = (
- reqparse.RequestParser()
- .add_argument("new_email", type=email, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
- @api.expect(parser_reset)
+ @console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
- args = parser_reset.parse_args()
+ payload = console_ns.payload or {}
+ args = ChangeEmailResetPayload.model_validate(payload)
- if AccountService.is_account_in_freeze(args["new_email"]):
+ if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args["new_email"]):
+ if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
- reset_data = AccountService.get_change_email_data(args["token"])
+ reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()
- AccountService.revoke_change_email_token(args["token"])
+ AccountService.revoke_change_email_token(args.token)
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()
- updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
+ updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
- email=args["new_email"],
+ email=args.new_email,
)
return updated_account
-parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
-
-
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
- @api.expect(parser_check)
+ @console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
@setup_required
def post(self):
- args = parser_check.parse_args()
- if AccountService.is_account_in_freeze(args["email"]):
+ payload = console_ns.payload or {}
+ args = CheckEmailUniquePayload.model_validate(payload)
+ if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args["email"]):
+ if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}
diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py
index 0a8f49d2e5..9527fe782e 100644
--- a/api/controllers/console/workspace/agent_providers.py
+++ b/api/controllers/console/workspace/agent_providers.py
@@ -1,6 +1,6 @@
from flask_restx import Resource, fields
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
@@ -9,9 +9,9 @@ from services.agent_service import AgentService
@console_ns.route("/workspaces/current/agent-providers")
class AgentProviderListApi(Resource):
- @api.doc("list_agent_providers")
- @api.doc(description="Get list of available agent providers")
- @api.response(
+ @console_ns.doc("list_agent_providers")
+ @console_ns.doc(description="Get list of available agent providers")
+ @console_ns.response(
200,
"Success",
fields.List(fields.Raw(description="Agent provider information")),
@@ -31,10 +31,10 @@ class AgentProviderListApi(Resource):
@console_ns.route("/workspaces/current/agent-provider/")
class AgentProviderApi(Resource):
- @api.doc("get_agent_provider")
- @api.doc(description="Get specific agent provider details")
- @api.doc(params={"provider_name": "Agent provider name"})
- @api.response(
+ @console_ns.doc("get_agent_provider")
+ @console_ns.doc(description="Get specific agent provider details")
+ @console_ns.doc(params={"provider_name": "Agent provider name"})
+ @console_ns.response(
200,
"Success",
fields.Raw(description="Agent provider details"),
diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py
index ae870a630e..bfd9fc6c29 100644
--- a/api/controllers/console/workspace/endpoint.py
+++ b/api/controllers/console/workspace/endpoint.py
@@ -1,33 +1,65 @@
-from flask_restx import Resource, fields, reqparse
+from typing import Any
-from controllers.console import api, console_ns
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
+
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class EndpointCreatePayload(BaseModel):
+ plugin_unique_identifier: str
+ settings: dict[str, Any]
+ name: str = Field(min_length=1)
+
+
+class EndpointIdPayload(BaseModel):
+ endpoint_id: str
+
+
+class EndpointUpdatePayload(EndpointIdPayload):
+ settings: dict[str, Any]
+ name: str = Field(min_length=1)
+
+
+class EndpointListQuery(BaseModel):
+ page: int = Field(ge=1)
+ page_size: int = Field(gt=0)
+
+
+class EndpointListForPluginQuery(EndpointListQuery):
+ plugin_id: str
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(EndpointCreatePayload)
+reg(EndpointIdPayload)
+reg(EndpointUpdatePayload)
+reg(EndpointListQuery)
+reg(EndpointListForPluginQuery)
+
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
- @api.doc("create_endpoint")
- @api.doc(description="Create a new plugin endpoint")
- @api.expect(
- api.model(
- "EndpointCreateRequest",
- {
- "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
- "settings": fields.Raw(required=True, description="Endpoint settings"),
- "name": fields.String(required=True, description="Endpoint name"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("create_endpoint")
+ @console_ns.doc(description="Create a new plugin endpoint")
+ @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
+ @console_ns.response(
200,
"Endpoint created successfully",
- api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
+ console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
)
- @api.response(403, "Admin privileges required")
+ @console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@@ -35,26 +67,16 @@ class EndpointCreateApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("plugin_unique_identifier", type=str, required=True)
- .add_argument("settings", type=dict, required=True)
- .add_argument("name", type=str, required=True)
- )
- args = parser.parse_args()
-
- plugin_unique_identifier = args["plugin_unique_identifier"]
- settings = args["settings"]
- name = args["name"]
+ args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
- plugin_unique_identifier=plugin_unique_identifier,
- name=name,
- settings=settings,
+ plugin_unique_identifier=args.plugin_unique_identifier,
+ name=args.name,
+ settings=args.settings,
)
}
except PluginPermissionDeniedError as e:
@@ -63,17 +85,15 @@ class EndpointCreateApi(Resource):
@console_ns.route("/workspaces/current/endpoints/list")
class EndpointListApi(Resource):
- @api.doc("list_endpoints")
- @api.doc(description="List plugin endpoints with pagination")
- @api.expect(
- api.parser()
- .add_argument("page", type=int, required=True, location="args", help="Page number")
- .add_argument("page_size", type=int, required=True, location="args", help="Page size")
- )
- @api.response(
+ @console_ns.doc("list_endpoints")
+ @console_ns.doc(description="List plugin endpoints with pagination")
+ @console_ns.expect(console_ns.models[EndpointListQuery.__name__])
+ @console_ns.response(
200,
"Success",
- api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
+ console_ns.model(
+ "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
+ ),
)
@setup_required
@login_required
@@ -81,15 +101,10 @@ class EndpointListApi(Resource):
def get(self):
user, tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("page", type=int, required=True, location="args")
- .add_argument("page_size", type=int, required=True, location="args")
- )
- args = parser.parse_args()
+ args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- page = args["page"]
- page_size = args["page_size"]
+ page = args.page
+ page_size = args.page_size
return jsonable_encoder(
{
@@ -105,18 +120,13 @@ class EndpointListApi(Resource):
@console_ns.route("/workspaces/current/endpoints/list/plugin")
class EndpointListForSinglePluginApi(Resource):
- @api.doc("list_plugin_endpoints")
- @api.doc(description="List endpoints for a specific plugin")
- @api.expect(
- api.parser()
- .add_argument("page", type=int, required=True, location="args", help="Page number")
- .add_argument("page_size", type=int, required=True, location="args", help="Page size")
- .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
- )
- @api.response(
+ @console_ns.doc("list_plugin_endpoints")
+ @console_ns.doc(description="List endpoints for a specific plugin")
+ @console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
+ @console_ns.response(
200,
"Success",
- api.model(
+ console_ns.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
)
@@ -126,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource):
def get(self):
user, tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("page", type=int, required=True, location="args")
- .add_argument("page_size", type=int, required=True, location="args")
- .add_argument("plugin_id", type=str, required=True, location="args")
- )
- args = parser.parse_args()
+ args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- page = args["page"]
- page_size = args["page_size"]
- plugin_id = args["plugin_id"]
+ page = args.page
+ page_size = args.page_size
+ plugin_id = args.plugin_id
return jsonable_encoder(
{
@@ -153,17 +157,15 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource):
- @api.doc("delete_endpoint")
- @api.doc(description="Delete a plugin endpoint")
- @api.expect(
- api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
- )
- @api.response(
+ @console_ns.doc("delete_endpoint")
+ @console_ns.doc(description="Delete a plugin endpoint")
+ @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
+ @console_ns.response(
200,
"Endpoint deleted successfully",
- api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
+ console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
)
- @api.response(403, "Admin privileges required")
+ @console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@@ -171,36 +173,26 @@ class EndpointDeleteApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
- args = parser.parse_args()
-
- endpoint_id = args["endpoint_id"]
+ args = EndpointIdPayload.model_validate(console_ns.payload)
return {
- "success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
+ "success": EndpointService.delete_endpoint(
+ tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
+ )
}
@console_ns.route("/workspaces/current/endpoints/update")
class EndpointUpdateApi(Resource):
- @api.doc("update_endpoint")
- @api.doc(description="Update a plugin endpoint")
- @api.expect(
- api.model(
- "EndpointUpdateRequest",
- {
- "endpoint_id": fields.String(required=True, description="Endpoint ID"),
- "settings": fields.Raw(required=True, description="Updated settings"),
- "name": fields.String(required=True, description="Updated name"),
- },
- )
- )
- @api.response(
+ @console_ns.doc("update_endpoint")
+ @console_ns.doc(description="Update a plugin endpoint")
+ @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
+ @console_ns.response(
200,
"Endpoint updated successfully",
- api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
+ console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
)
- @api.response(403, "Admin privileges required")
+ @console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@@ -208,42 +200,30 @@ class EndpointUpdateApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("endpoint_id", type=str, required=True)
- .add_argument("settings", type=dict, required=True)
- .add_argument("name", type=str, required=True)
- )
- args = parser.parse_args()
-
- endpoint_id = args["endpoint_id"]
- settings = args["settings"]
- name = args["name"]
+ args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
- endpoint_id=endpoint_id,
- name=name,
- settings=settings,
+ endpoint_id=args.endpoint_id,
+ name=args.name,
+ settings=args.settings,
)
}
@console_ns.route("/workspaces/current/endpoints/enable")
class EndpointEnableApi(Resource):
- @api.doc("enable_endpoint")
- @api.doc(description="Enable a plugin endpoint")
- @api.expect(
- api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
- )
- @api.response(
+ @console_ns.doc("enable_endpoint")
+ @console_ns.doc(description="Enable a plugin endpoint")
+ @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
+ @console_ns.response(
200,
"Endpoint enabled successfully",
- api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
+ console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
)
- @api.response(403, "Admin privileges required")
+ @console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@@ -251,29 +231,26 @@ class EndpointEnableApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
- args = parser.parse_args()
-
- endpoint_id = args["endpoint_id"]
+ args = EndpointIdPayload.model_validate(console_ns.payload)
return {
- "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
+ "success": EndpointService.enable_endpoint(
+ tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
+ )
}
@console_ns.route("/workspaces/current/endpoints/disable")
class EndpointDisableApi(Resource):
- @api.doc("disable_endpoint")
- @api.doc(description="Disable a plugin endpoint")
- @api.expect(
- api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
- )
- @api.response(
+ @console_ns.doc("disable_endpoint")
+ @console_ns.doc(description="Disable a plugin endpoint")
+ @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
+ @console_ns.response(
200,
"Endpoint disabled successfully",
- api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
+ console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
)
- @api.response(403, "Admin privileges required")
+ @console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@@ -281,11 +258,10 @@ class EndpointDisableApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
- parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
- args = parser.parse_args()
-
- endpoint_id = args["endpoint_id"]
+ args = EndpointIdPayload.model_validate(console_ns.payload)
return {
- "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
+ "success": EndpointService.disable_endpoint(
+ tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
+ )
}
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 3ca453f1da..0142e14fb0 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,11 +1,12 @@
from urllib import parse
from flask import abort, request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
import services
from configs import dify_config
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@@ -31,6 +32,42 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class MemberInvitePayload(BaseModel):
+ emails: list[str] = Field(default_factory=list)
+ role: TenantAccountRole
+ language: str | None = None
+
+
+class MemberRoleUpdatePayload(BaseModel):
+ role: str
+
+
+class OwnerTransferEmailPayload(BaseModel):
+ language: str | None = None
+
+
+class OwnerTransferCheckPayload(BaseModel):
+ code: str
+ token: str
+
+
+class OwnerTransferPayload(BaseModel):
+ token: str
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(MemberInvitePayload)
+reg(MemberRoleUpdatePayload)
+reg(OwnerTransferEmailPayload)
+reg(OwnerTransferCheckPayload)
+reg(OwnerTransferPayload)
+
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
@@ -48,29 +85,22 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
-parser_invite = (
- reqparse.RequestParser()
- .add_argument("emails", type=list, required=True, location="json")
- .add_argument("role", type=str, required=True, default="admin", location="json")
- .add_argument("language", type=str, required=False, location="json")
-)
-
-
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
- @api.expect(parser_invite)
+ @console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
- args = parser_invite.parse_args()
+ payload = console_ns.payload or {}
+ args = MemberInvitePayload.model_validate(payload)
- invitee_emails = args["emails"]
- invitee_role = args["role"]
- interface_language = args["language"]
+ invitee_emails = args.emails
+ invitee_role = args.role
+ interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
@@ -146,20 +176,18 @@ class MemberCancelInviteApi(Resource):
}, 200
-parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
-
-
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
- @api.expect(parser_update)
+ @console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
- args = parser_update.parse_args()
- new_role = args["role"]
+ payload = console_ns.payload or {}
+ args = MemberRoleUpdatePayload.model_validate(payload)
+ new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
@@ -197,20 +225,18 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
-parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
-
-
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
- @api.expect(parser_send)
+ @console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
- args = parser_send.parse_args()
+ payload = console_ns.payload or {}
+ args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -221,7 +247,7 @@ class SendOwnerTransferEmailApi(Resource):
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
@@ -238,22 +264,16 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
-parser_owner = (
- reqparse.RequestParser()
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
- @api.expect(parser_owner)
+ @console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
- args = parser_owner.parse_args()
+ payload = console_ns.payload or {}
+ args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@@ -267,41 +287,37 @@ class OwnerTransferCheckApi(Resource):
if is_owner_transfer_error_rate_limit:
raise OwnerTransferLimitError()
- token_data = AccountService.get_owner_transfer_data(args["token"])
+ token_data = AccountService.get_owner_transfer_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
+ if args.code != token_data.get("code"):
AccountService.add_owner_transfer_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_owner_transfer_token(args["token"])
+ AccountService.revoke_owner_transfer_token(args.token)
# Refresh token data by generating a new token
- _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
+ _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={})
AccountService.reset_owner_transfer_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
-parser_owner_transfer = reqparse.RequestParser().add_argument(
- "token", type=str, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
- @api.expect(parser_owner_transfer)
+ @console_ns.expect(console_ns.models[OwnerTransferPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
- args = parser_owner_transfer.parse_args()
+ payload = console_ns.payload or {}
+ args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
@@ -313,14 +329,14 @@ class OwnerTransfer(Resource):
if current_user.id == str(member_id):
raise CannotTransferOwnerToSelfError()
- transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
+ transfer_token_data = AccountService.get_owner_transfer_data(args.token)
if not transfer_token_data:
raise InvalidTokenError()
if transfer_token_data.get("email") != current_user.email:
raise InvalidEmailError()
- AccountService.revoke_owner_transfer_token(args["token"])
+ AccountService.revoke_owner_transfer_token(args.token)
member = db.session.get(Account, str(member_id))
if not member:
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index 05731b3832..7bada2fa12 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -1,31 +1,97 @@
import io
+from typing import Any, Literal
-from flask import send_file
-from flask_restx import Resource, reqparse
+from flask import request, send_file
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
-from libs.helper import StrLen, uuid_value
+from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
-parser_model = reqparse.RequestParser().add_argument(
- "model_type",
- type=str,
- required=False,
- nullable=True,
- choices=[mt.value for mt in ModelType],
- location="args",
-)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ParserModelList(BaseModel):
+ model_type: ModelType | None = None
+
+
+class ParserCredentialId(BaseModel):
+ credential_id: str | None = None
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_optional_credential_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ParserCredentialCreate(BaseModel):
+ credentials: dict[str, Any]
+ name: str | None = Field(default=None, max_length=30)
+
+
+class ParserCredentialUpdate(BaseModel):
+ credential_id: str
+ credentials: dict[str, Any]
+ name: str | None = Field(default=None, max_length=30)
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_update_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserCredentialDelete(BaseModel):
+ credential_id: str
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_delete_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserCredentialSwitch(BaseModel):
+ credential_id: str
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_switch_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserCredentialValidate(BaseModel):
+ credentials: dict[str, Any]
+
+
+class ParserPreferredProviderType(BaseModel):
+ preferred_provider_type: Literal["system", "custom"]
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(ParserModelList)
+reg(ParserCredentialId)
+reg(ParserCredentialCreate)
+reg(ParserCredentialUpdate)
+reg(ParserCredentialDelete)
+reg(ParserCredentialSwitch)
+reg(ParserCredentialValidate)
+reg(ParserPreferredProviderType)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
- @api.expect(parser_model)
+ @console_ns.expect(console_ns.models[ParserModelList.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -33,38 +99,18 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
- args = parser_model.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService()
- provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
+ provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type)
return jsonable_encoder({"data": provider_list})
-parser_cred = reqparse.RequestParser().add_argument(
- "credential_id", type=uuid_value, required=False, nullable=True, location="args"
-)
-parser_post_cred = (
- reqparse.RequestParser()
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
-)
-
-parser_put_cred = (
- reqparse.RequestParser()
- .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
-)
-
-parser_delete_cred = reqparse.RequestParser().add_argument(
- "credential_id", type=uuid_value, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
- @api.expect(parser_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialId.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -72,23 +118,25 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
- args = parser_cred.parse_args()
+ payload = request.args.to_dict(flat=True) # type: ignore
+ args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
- tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
+ tenant_id=tenant_id, provider=provider, credential_id=args.credential_id
)
return {"credentials": credentials}
- @api.expect(parser_post_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialCreate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_post_cred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialCreate.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -96,15 +144,15 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.create_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
- credentials=args["credentials"],
- credential_name=args["name"],
+ credentials=args.credentials,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
- @api.expect(parser_put_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -112,7 +160,8 @@ class ModelProviderCredentialApi(Resource):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_put_cred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialUpdate.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -120,71 +169,64 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.update_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
- credentials=args["credentials"],
- credential_id=args["credential_id"],
- credential_name=args["name"],
+ credentials=args.credentials,
+ credential_id=args.credential_id,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
- @api.expect(parser_delete_cred)
+ @console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_delete_cred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialDelete.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
- tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
+ tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return {"result": "success"}, 204
-parser_switch = reqparse.RequestParser().add_argument(
- "credential_id", type=str, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
- @api.expect(parser_switch)
+ @console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_switch.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialSwitch.model_validate(payload)
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
- credential_id=args["credential_id"],
+ credential_id=args.credential_id,
)
return {"result": "success"}
-parser_validate = reqparse.RequestParser().add_argument(
- "credentials", type=dict, required=True, nullable=False, location="json"
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
- @api.expect(parser_validate)
+ @console_ns.expect(console_ns.models[ParserCredentialValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_validate.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserCredentialValidate.model_validate(payload)
tenant_id = current_tenant_id
@@ -195,7 +237,7 @@ class ModelProviderValidateApi(Resource):
try:
model_provider_service.validate_provider_credentials(
- tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
+ tenant_id=tenant_id, provider=provider, credentials=args.credentials
)
except CredentialsValidateFailedError as ex:
result = False
@@ -228,19 +270,9 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
-parser_preferred = reqparse.RequestParser().add_argument(
- "preferred_provider_type",
- type=str,
- required=True,
- nullable=False,
- choices=["system", "custom"],
- location="json",
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
- @api.expect(parser_preferred)
+ @console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -250,11 +282,12 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id
- args = parser_preferred.parse_args()
+ payload = console_ns.payload or {}
+ args = ParserPreferredProviderType.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
- tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
+ tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type
)
return {"result": "success"}
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 79079f692e..a5b45ef514 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,52 +1,144 @@
import logging
+from typing import Any, cast
-from flask_restx import Resource, reqparse
+from flask import request
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
-from libs.helper import StrLen, uuid_value
+from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-parser_get_default = reqparse.RequestParser().add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="args",
-)
-parser_post_default = reqparse.RequestParser().add_argument(
- "model_settings", type=list, required=True, nullable=False, location="json"
-)
+class ParserGetDefault(BaseModel):
+ model_type: ModelType
+
+
+class ParserPostDefault(BaseModel):
+ class Inner(BaseModel):
+ model_type: ModelType
+ model: str | None = None
+ provider: str | None = None
+
+ model_settings: list[Inner]
+
+
+class ParserDeleteModels(BaseModel):
+ model: str
+ model_type: ModelType
+
+
+class LoadBalancingPayload(BaseModel):
+ configs: list[dict[str, Any]] | None = None
+ enabled: bool | None = None
+
+
+class ParserPostModels(BaseModel):
+ model: str
+ model_type: ModelType
+ load_balancing: LoadBalancingPayload | None = None
+ config_from: str | None = None
+ credential_id: str | None = None
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_credential_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ParserGetCredentials(BaseModel):
+ model: str
+ model_type: ModelType
+ config_from: str | None = None
+ credential_id: str | None = None
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_get_credential_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ParserCredentialBase(BaseModel):
+ model: str
+ model_type: ModelType
+
+
+class ParserCreateCredential(ParserCredentialBase):
+ name: str | None = Field(default=None, max_length=30)
+ credentials: dict[str, Any]
+
+
+class ParserUpdateCredential(ParserCredentialBase):
+ credential_id: str
+ credentials: dict[str, Any]
+ name: str | None = Field(default=None, max_length=30)
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_update_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserDeleteCredential(ParserCredentialBase):
+ credential_id: str
+
+ @field_validator("credential_id")
+ @classmethod
+ def validate_delete_credential_id(cls, value: str) -> str:
+ return uuid_value(value)
+
+
+class ParserParameter(BaseModel):
+ model: str
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(ParserGetDefault)
+reg(ParserPostDefault)
+reg(ParserDeleteModels)
+reg(ParserPostModels)
+reg(ParserGetCredentials)
+reg(ParserCreateCredential)
+reg(ParserUpdateCredential)
+reg(ParserDeleteCredential)
+reg(ParserParameter)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
- @api.expect(parser_get_default)
+ @console_ns.expect(console_ns.models[ParserGetDefault.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_get_default.parse_args()
+ args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
- tenant_id=tenant_id, model_type=args["model_type"]
+ tenant_id=tenant_id, model_type=args.model_type
)
return jsonable_encoder({"data": default_model_entity})
- @api.expect(parser_post_default)
+ @console_ns.expect(console_ns.models[ParserPostDefault.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -54,66 +146,31 @@ class DefaultModelApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_post_default.parse_args()
+ args = ParserPostDefault.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
- model_settings = args["model_settings"]
+ model_settings = args.model_settings
for model_setting in model_settings:
- if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
- raise ValueError("invalid model type")
-
- if "provider" not in model_setting:
+ if model_setting.provider is None:
continue
- if "model" not in model_setting:
- raise ValueError("invalid model")
-
try:
model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id,
- model_type=model_setting["model_type"],
- provider=model_setting["provider"],
- model=model_setting["model"],
+ model_type=model_setting.model_type,
+ provider=model_setting.provider,
+ model=cast(str, model_setting.model),
)
except Exception as ex:
logger.exception(
"Failed to update default model, model type: %s, model: %s",
- model_setting["model_type"],
- model_setting.get("model"),
+ model_setting.model_type,
+ model_setting.model,
)
raise ex
return {"result": "success"}
-parser_post_models = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
- .add_argument("config_from", type=str, required=False, nullable=True, location="json")
- .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
-)
-parser_delete_models = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@@ -127,7 +184,7 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
- @api.expect(parser_post_models)
+ @console_ns.expect(console_ns.models[ParserPostModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -135,45 +192,45 @@ class ModelProviderModelApi(Resource):
def post(self, provider: str):
# To save the model's load balance configs
_, tenant_id = current_account_with_tenant()
- args = parser_post_models.parse_args()
+ args = ParserPostModels.model_validate(console_ns.payload)
- if args.get("config_from", "") == "custom-model":
- if not args.get("credential_id"):
+ if args.config_from == "custom-model":
+ if not args.credential_id:
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args["credential_id"],
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
- if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
+ if args.load_balancing and args.load_balancing.configs:
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- configs=args["load_balancing"]["configs"],
- config_from=args.get("config_from", ""),
+ model=args.model,
+ model_type=args.model_type,
+ configs=args.load_balancing.configs,
+ config_from=args.config_from or "",
)
- if args.get("load_balancing", {}).get("enabled"):
+ if args.load_balancing.enabled:
model_load_balancing_service.enable_model_load_balancing(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
else:
model_load_balancing_service.disable_model_load_balancing(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 200
- @api.expect(parser_delete_models)
+ @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -181,113 +238,53 @@ class ModelProviderModelApi(Resource):
def delete(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_delete_models.parse_args()
+ args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 204
-parser_get_credentials = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="args")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="args",
- )
- .add_argument("config_from", type=str, required=False, nullable=True, location="args")
- .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
-)
-
-
-parser_post_cred = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-)
-parser_put_cred = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
-)
-parser_delete_cred = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
- @api.expect(parser_get_credentials)
+ @console_ns.expect(console_ns.models[ParserGetCredentials.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_get_credentials.parse_args()
+ args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
tenant_id=tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args.get("credential_id"),
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- config_from=args.get("config_from", ""),
+ model=args.model,
+ model_type=args.model_type,
+ config_from=args.config_from or "",
)
- if args.get("config_from", "") == "predefined-model":
+ if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
)
else:
- model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
+ model_type = args.model_type
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
- tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
+ tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
)
return jsonable_encoder(
@@ -304,7 +301,7 @@ class ModelProviderModelCredentialApi(Resource):
}
)
- @api.expect(parser_post_cred)
+ @console_ns.expect(console_ns.models[ParserCreateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -312,7 +309,7 @@ class ModelProviderModelCredentialApi(Resource):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_post_cred.parse_args()
+ args = ParserCreateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -320,30 +317,30 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.create_model_credential(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
- credential_name=args["name"],
+ model=args.model,
+ model_type=args.model_type,
+ credentials=args.credentials,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
logger.exception(
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
tenant_id,
- args.get("model"),
- args.get("model_type"),
+ args.model,
+ args.model_type,
)
raise ValueError(str(ex))
return {"result": "success"}, 201
- @api.expect(parser_put_cred)
+ @console_ns.expect(console_ns.models[ParserUpdateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_put_cred.parse_args()
+ args = ParserUpdateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -351,106 +348,87 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.update_model_credential(
tenant_id=current_tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credentials=args["credentials"],
- credential_id=args["credential_id"],
- credential_name=args["name"],
+ model_type=args.model_type,
+ model=args.model,
+ credentials=args.credentials,
+ credential_id=args.credential_id,
+ credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
- @api.expect(parser_delete_cred)
+ @console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
- args = parser_delete_cred.parse_args()
+ args = ParserDeleteCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args["credential_id"],
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
return {"result": "success"}, 204
-parser_switch = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+class ParserSwitch(BaseModel):
+ model: str
+ model_type: ModelType
+ credential_id: str
+
+
+console_ns.schema_model(
+ ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
- @api.expect(parser_switch)
+ @console_ns.expect(console_ns.models[ParserSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
-
- args = parser_switch.parse_args()
+ args = ParserSwitch.model_validate(console_ns.payload)
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_tenant_id,
provider=provider,
- model_type=args["model_type"],
- model=args["model"],
- credential_id=args["credential_id"],
+ model_type=args.model_type,
+ model=args.model,
+ credential_id=args.credential_id,
)
return {"result": "success"}
-parser_model_enable_disable = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
-)
-
-
@console_ns.route(
    "/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
- @api.expect(parser_model_enable_disable)
+ @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_model_enable_disable.parse_args()
+ args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
@@ -460,48 +438,43 @@ class ModelProviderModelEnableApi(Resource):
    "/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
- @api.expect(parser_model_enable_disable)
+ @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
- args = parser_model_enable_disable.parse_args()
+ args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
- tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
+ tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
-parser_validate = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+class ParserValidate(BaseModel):
+ model: str
+ model_type: ModelType
+ credentials: dict
+
+
+console_ns.schema_model(
+ ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
- @api.expect(parser_validate)
+ @console_ns.expect(console_ns.models[ParserValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
-
- args = parser_validate.parse_args()
+ args = ParserValidate.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@@ -512,9 +485,9 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service.validate_model_credentials(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
+ model=args.model,
+ model_type=args.model_type,
+ credentials=args.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@@ -528,24 +501,19 @@ class ModelProviderModelValidateApi(Resource):
return response
-parser_parameter = reqparse.RequestParser().add_argument(
- "model", type=str, required=True, nullable=False, location="args"
-)
-
-
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
- @api.expect(parser_parameter)
+ @console_ns.expect(console_ns.models[ParserParameter.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
- args = parser_parameter.parse_args()
+ args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
- tenant_id=tenant_id, provider=provider, model=args["model"]
+ tenant_id=tenant_id, provider=provider, model=args.model
)
return jsonable_encoder({"data": parameter_rules})
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index deae418e96..c5624e0fc2 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -1,11 +1,13 @@
import io
+from typing import Literal
from flask import request, send_file
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -17,6 +19,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+def reg(cls: type[BaseModel]):
+ console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@@ -37,88 +45,194 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
-parser_list = (
- reqparse.RequestParser()
- .add_argument("page", type=int, required=False, location="args", default=1)
- .add_argument("page_size", type=int, required=False, location="args", default=256)
-)
+class ParserList(BaseModel):
+ page: int = Field(default=1)
+ page_size: int = Field(default=256)
+
+
+reg(ParserList)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
- @api.expect(parser_list)
+ @console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_list.parse_args()
+ args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
+ plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
-parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
+class ParserLatest(BaseModel):
+ plugin_ids: list[str]
+
+
+class ParserIcon(BaseModel):
+ tenant_id: str
+ filename: str
+
+
+class ParserAsset(BaseModel):
+ plugin_unique_identifier: str
+ file_name: str
+
+
+class ParserGithubUpload(BaseModel):
+ repo: str
+ version: str
+ package: str
+
+
+class ParserPluginIdentifiers(BaseModel):
+ plugin_unique_identifiers: list[str]
+
+
+class ParserGithubInstall(BaseModel):
+ plugin_unique_identifier: str
+ repo: str
+ version: str
+ package: str
+
+
+class ParserPluginIdentifierQuery(BaseModel):
+ plugin_unique_identifier: str
+
+
+class ParserTasks(BaseModel):
+ page: int
+ page_size: int
+
+
+class ParserMarketplaceUpgrade(BaseModel):
+ original_plugin_unique_identifier: str
+ new_plugin_unique_identifier: str
+
+
+class ParserGithubUpgrade(BaseModel):
+ original_plugin_unique_identifier: str
+ new_plugin_unique_identifier: str
+ repo: str
+ version: str
+ package: str
+
+
+class ParserUninstall(BaseModel):
+ plugin_installation_id: str
+
+
+class ParserPermissionChange(BaseModel):
+ install_permission: TenantPluginPermission.InstallPermission
+ debug_permission: TenantPluginPermission.DebugPermission
+
+
+class ParserDynamicOptions(BaseModel):
+ plugin_id: str
+ provider: str
+ action: str
+ parameter: str
+ credential_id: str | None = None
+ provider_type: Literal["tool", "trigger"]
+
+
+class PluginPermissionSettingsPayload(BaseModel):
+ install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
+ debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
+
+
+class PluginAutoUpgradeSettingsPayload(BaseModel):
+ strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
+ TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
+ )
+ upgrade_time_of_day: int = 0
+ upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
+ exclude_plugins: list[str] = Field(default_factory=list)
+ include_plugins: list[str] = Field(default_factory=list)
+
+
+class ParserPreferencesChange(BaseModel):
+ permission: PluginPermissionSettingsPayload
+ auto_upgrade: PluginAutoUpgradeSettingsPayload
+
+
+class ParserExcludePlugin(BaseModel):
+ plugin_id: str
+
+
+class ParserReadme(BaseModel):
+ plugin_unique_identifier: str
+ language: str = Field(default="en-US")
+
+
+reg(ParserLatest)
+reg(ParserIcon)
+reg(ParserAsset)
+reg(ParserGithubUpload)
+reg(ParserPluginIdentifiers)
+reg(ParserGithubInstall)
+reg(ParserPluginIdentifierQuery)
+reg(ParserTasks)
+reg(ParserMarketplaceUpgrade)
+reg(ParserGithubUpgrade)
+reg(ParserUninstall)
+reg(ParserPermissionChange)
+reg(ParserDynamicOptions)
+reg(ParserPreferencesChange)
+reg(ParserExcludePlugin)
+reg(ParserReadme)
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
- @api.expect(parser_latest)
+ @console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
- args = parser_latest.parse_args()
+ args = ParserLatest.model_validate(console_ns.payload)
try:
- versions = PluginService.list_latest_versions(args["plugin_ids"])
+ versions = PluginService.list_latest_versions(args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"versions": versions})
-parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
-
-
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
- @api.expect(parser_ids)
+ @console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_ids.parse_args()
+ args = ParserLatest.model_validate(console_ns.payload)
try:
- plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
+ plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
-parser_icon = (
- reqparse.RequestParser()
- .add_argument("tenant_id", type=str, required=True, location="args")
- .add_argument("filename", type=str, required=True, location="args")
-)
-
-
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
- @api.expect(parser_icon)
+ @console_ns.expect(console_ns.models[ParserIcon.__name__])
@setup_required
def get(self):
- args = parser_icon.parse_args()
+ args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
+ icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -128,20 +242,16 @@ class PluginIconApi(Resource):
@console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource):
+ @console_ns.expect(console_ns.models[ParserAsset.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
- req = (
- reqparse.RequestParser()
- .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
- .add_argument("file_name", type=str, required=True, location="args")
- )
- args = req.parse_args()
+ args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
try:
- binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
+ binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -171,17 +281,9 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
-parser_github = (
- reqparse.RequestParser()
- .add_argument("repo", type=str, required=True, location="json")
- .add_argument("version", type=str, required=True, location="json")
- .add_argument("package", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
- @api.expect(parser_github)
+ @console_ns.expect(console_ns.models[ParserGithubUpload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -189,10 +291,10 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_github.parse_args()
+ args = ParserGithubUpload.model_validate(console_ns.payload)
try:
- response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
+ response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -223,47 +325,28 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
-parser_pkg = reqparse.RequestParser().add_argument(
- "plugin_unique_identifiers", type=list, required=True, location="json"
-)
-
-
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
- @api.expect(parser_pkg)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_pkg.parse_args()
-
- # check if all plugin_unique_identifiers are valid string
- for plugin_unique_identifier in args["plugin_unique_identifiers"]:
- if not isinstance(plugin_unique_identifier, str):
- raise ValueError("Invalid plugin unique identifier")
+ args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
- response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
+ response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
-parser_githubapi = (
- reqparse.RequestParser()
- .add_argument("repo", type=str, required=True, location="json")
- .add_argument("version", type=str, required=True, location="json")
- .add_argument("package", type=str, required=True, location="json")
- .add_argument("plugin_unique_identifier", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
- @api.expect(parser_githubapi)
+ @console_ns.expect(console_ns.models[ParserGithubInstall.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -271,15 +354,15 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_githubapi.parse_args()
+ args = ParserGithubInstall.model_validate(console_ns.payload)
try:
response = PluginService.install_from_github(
tenant_id,
- args["plugin_unique_identifier"],
- args["repo"],
- args["version"],
- args["package"],
+ args.plugin_unique_identifier,
+ args.repo,
+ args.version,
+ args.package,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -287,14 +370,9 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
-parser_marketplace = reqparse.RequestParser().add_argument(
- "plugin_unique_identifiers", type=list, required=True, location="json"
-)
-
-
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
- @api.expect(parser_marketplace)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -302,43 +380,33 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_marketplace.parse_args()
-
- # check if all plugin_unique_identifiers are valid string
- for plugin_unique_identifier in args["plugin_unique_identifiers"]:
- if not isinstance(plugin_unique_identifier, str):
- raise ValueError("Invalid plugin unique identifier")
+ args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
- response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
+ response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
-parser_pkgapi = reqparse.RequestParser().add_argument(
- "plugin_unique_identifier", type=str, required=True, location="args"
-)
-
-
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
- @api.expect(parser_pkgapi)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_pkgapi.parse_args()
+ args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
- args["plugin_unique_identifier"],
+ args.plugin_unique_identifier,
)
}
)
@@ -346,14 +414,9 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
-parser_fetch = reqparse.RequestParser().add_argument(
- "plugin_unique_identifier", type=str, required=True, location="args"
-)
-
-
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
- @api.expect(parser_fetch)
+ @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -361,30 +424,19 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_fetch.parse_args()
+ args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
- {
- "manifest": PluginService.fetch_plugin_manifest(
- tenant_id, args["plugin_unique_identifier"]
- ).model_dump()
- }
+ {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_tasks = (
- reqparse.RequestParser()
- .add_argument("page", type=int, required=True, location="args")
- .add_argument("page_size", type=int, required=True, location="args")
-)
-
-
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
- @api.expect(parser_tasks)
+ @console_ns.expect(console_ns.models[ParserTasks.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -392,12 +444,10 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
- args = parser_tasks.parse_args()
+ args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
- return jsonable_encoder(
- {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
- )
+ return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -462,16 +512,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
-parser_marketplace_api = (
- reqparse.RequestParser()
- .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
- .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
- @api.expect(parser_marketplace_api)
+ @console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -479,31 +522,21 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_marketplace_api.parse_args()
+ args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
- tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
+ tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_github_post = (
- reqparse.RequestParser()
- .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
- .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
- .add_argument("repo", type=str, required=True, location="json")
- .add_argument("version", type=str, required=True, location="json")
- .add_argument("package", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
- @api.expect(parser_github_post)
+ @console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -511,56 +544,44 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
- args = parser_github_post.parse_args()
+ args = ParserGithubUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
- args["original_plugin_unique_identifier"],
- args["new_plugin_unique_identifier"],
- args["repo"],
- args["version"],
- args["package"],
+ args.original_plugin_unique_identifier,
+ args.new_plugin_unique_identifier,
+ args.repo,
+ args.version,
+ args.package,
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_uninstall = reqparse.RequestParser().add_argument(
- "plugin_installation_id", type=str, required=True, location="json"
-)
-
-
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
- @api.expect(parser_uninstall)
+ @console_ns.expect(console_ns.models[ParserUninstall.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
- args = parser_uninstall.parse_args()
+ args = ParserUninstall.model_validate(console_ns.payload)
_, tenant_id = current_account_with_tenant()
try:
- return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
+ return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
-parser_change_post = (
- reqparse.RequestParser()
- .add_argument("install_permission", type=str, required=True, location="json")
- .add_argument("debug_permission", type=str, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
- @api.expect(parser_change_post)
+ @console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -570,14 +591,15 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
- args = parser_change_post.parse_args()
-
- install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
- debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
+ args = ParserPermissionChange.model_validate(console_ns.payload)
tenant_id = current_tenant_id
- return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
+ return {
+ "success": PluginPermissionService.change_permission(
+ tenant_id, args.install_permission, args.debug_permission
+ )
+ }
@console_ns.route("/workspaces/current/plugin/permission/fetch")
@@ -605,20 +627,9 @@ class PluginFetchPermissionApi(Resource):
)
-parser_dynamic = (
- reqparse.RequestParser()
- .add_argument("plugin_id", type=str, required=True, location="args")
- .add_argument("provider", type=str, required=True, location="args")
- .add_argument("action", type=str, required=True, location="args")
- .add_argument("parameter", type=str, required=True, location="args")
- .add_argument("credential_id", type=str, required=False, location="args")
- .add_argument("provider_type", type=str, required=True, location="args")
-)
-
-
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
- @api.expect(parser_dynamic)
+ @console_ns.expect(console_ns.models[ParserDynamicOptions.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -627,18 +638,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
- args = parser_dynamic.parse_args()
+ args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
- plugin_id=args["plugin_id"],
- provider=args["provider"],
- action=args["action"],
- parameter=args["parameter"],
- credential_id=args["credential_id"],
- provider_type=args["provider_type"],
+ plugin_id=args.plugin_id,
+ provider=args.provider,
+ action=args.action,
+ parameter=args.parameter,
+ credential_id=args.credential_id,
+ provider_type=args.provider_type,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -646,16 +657,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
-parser_change = (
- reqparse.RequestParser()
- .add_argument("permission", type=dict, required=True, location="json")
- .add_argument("auto_upgrade", type=dict, required=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
- @api.expect(parser_change)
+ @console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -664,22 +668,20 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
- args = parser_change.parse_args()
+ args = ParserPreferencesChange.model_validate(console_ns.payload)
- permission = args["permission"]
+ permission = args.permission
- install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
- debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
+ install_permission = permission.install_permission
+ debug_permission = permission.debug_permission
- auto_upgrade = args["auto_upgrade"]
+ auto_upgrade = args.auto_upgrade
- strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
- auto_upgrade.get("strategy_setting", "fix_only")
- )
- upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
- upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
- exclude_plugins = auto_upgrade.get("exclude_plugins", [])
- include_plugins = auto_upgrade.get("include_plugins", [])
+ strategy_setting = auto_upgrade.strategy_setting
+ upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
+ upgrade_mode = auto_upgrade.upgrade_mode
+ exclude_plugins = auto_upgrade.exclude_plugins
+ include_plugins = auto_upgrade.include_plugins
# set permission
set_permission_result = PluginPermissionService.change_permission(
@@ -744,12 +746,9 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
-parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
-
-
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
- @api.expect(parser_exclude)
+ @console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -757,28 +756,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
- args = parser_exclude.parse_args()
+ args = ParserExcludePlugin.model_validate(console_ns.payload)
- return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
+ return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
@console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource):
+ @console_ns.expect(console_ns.models[ParserReadme.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
- parser = (
- reqparse.RequestParser()
- .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
- .add_argument("language", type=str, required=False, location="args")
- )
- args = parser.parse_args()
+ args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore
return jsonable_encoder(
- {
- "readme": PluginService.fetch_plugin_readme(
- tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
- )
- }
+ {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
)
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 917059bb4c..2c54aa5a20 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
-from controllers.console import api, console_ns
+from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
@@ -65,7 +65,7 @@ parser_tool = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
- @api.expect(parser_tool)
+ @console_ns.expect(parser_tool)
@setup_required
@login_required
@account_initialization_required
@@ -113,7 +113,7 @@ parser_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
- @api.expect(parser_delete)
+ @console_ns.expect(parser_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -140,7 +140,7 @@ parser_add = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
- @api.expect(parser_add)
+ @console_ns.expect(parser_add)
@setup_required
@login_required
@account_initialization_required
@@ -174,7 +174,7 @@ parser_update = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
- @api.expect(parser_update)
+ @console_ns.expect(parser_update)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -236,7 +236,7 @@ parser_api_add = (
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
- @api.expect(parser_api_add)
+ @console_ns.expect(parser_api_add)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -267,7 +267,7 @@ parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
- @api.expect(parser_remote)
+ @console_ns.expect(parser_remote)
@setup_required
@login_required
@account_initialization_required
@@ -292,7 +292,7 @@ parser_tools = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
- @api.expect(parser_tools)
+ @console_ns.expect(parser_tools)
@setup_required
@login_required
@account_initialization_required
@@ -328,7 +328,7 @@ parser_api_update = (
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
- @api.expect(parser_api_update)
+ @console_ns.expect(parser_api_update)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -362,7 +362,7 @@ parser_api_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
- @api.expect(parser_api_delete)
+ @console_ns.expect(parser_api_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -386,7 +386,7 @@ parser_get = reqparse.RequestParser().add_argument("provider", type=str, require
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
- @api.expect(parser_get)
+ @console_ns.expect(parser_get)
@setup_required
@login_required
@account_initialization_required
@@ -426,7 +426,7 @@ parser_schema = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
- @api.expect(parser_schema)
+ @console_ns.expect(parser_schema)
@setup_required
@login_required
@account_initialization_required
@@ -451,7 +451,7 @@ parser_pre = (
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
- @api.expect(parser_pre)
+ @console_ns.expect(parser_pre)
@setup_required
@login_required
@account_initialization_required
@@ -484,7 +484,7 @@ parser_create = (
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
- @api.expect(parser_create)
+ @console_ns.expect(parser_create)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -525,7 +525,7 @@ parser_workflow_update = (
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
- @api.expect(parser_workflow_update)
+ @console_ns.expect(parser_workflow_update)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -560,7 +560,7 @@ parser_workflow_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
- @api.expect(parser_workflow_delete)
+ @console_ns.expect(parser_workflow_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -588,7 +588,7 @@ parser_wf_get = (
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
- @api.expect(parser_wf_get)
+ @console_ns.expect(parser_wf_get)
@setup_required
@login_required
@account_initialization_required
@@ -624,7 +624,7 @@ parser_wf_tools = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
- @api.expect(parser_wf_tools)
+ @console_ns.expect(parser_wf_tools)
@setup_required
@login_required
@account_initialization_required
@@ -813,7 +813,7 @@ parser_default_cred = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
- @api.expect(parser_default_cred)
+ @console_ns.expect(parser_default_cred)
@setup_required
@login_required
@account_initialization_required
@@ -834,7 +834,7 @@ parser_custom = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
- @api.expect(parser_custom)
+ @console_ns.expect(parser_custom)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -932,7 +932,7 @@ parser_mcp_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
- @api.expect(parser_mcp)
+ @console_ns.expect(parser_mcp)
@setup_required
@login_required
@account_initialization_required
@@ -962,7 +962,7 @@ class ToolProviderMCPApi(Resource):
)
return jsonable_encoder(result)
- @api.expect(parser_mcp_put)
+ @console_ns.expect(parser_mcp_put)
@setup_required
@login_required
@account_initialization_required
@@ -1001,7 +1001,7 @@ class ToolProviderMCPApi(Resource):
)
return {"result": "success"}
- @api.expect(parser_mcp_delete)
+ @console_ns.expect(parser_mcp_delete)
@setup_required
@login_required
@account_initialization_required
@@ -1024,7 +1024,7 @@ parser_auth = (
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
- @api.expect(parser_auth)
+ @console_ns.expect(parser_auth)
@setup_required
@login_required
@account_initialization_required
@@ -1142,7 +1142,7 @@ parser_cb = (
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
- @api.expect(parser_cb)
+ @console_ns.expect(parser_cb)
def get(self):
args = parser_cb.parse_args()
state_key = args["state"]
diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py
index b2abae0b3d..268473d6d1 100644
--- a/api/controllers/console/workspace/trigger_providers.py
+++ b/api/controllers/console/workspace/trigger_providers.py
@@ -6,8 +6,6 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
-from controllers.console import api
-from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -23,9 +21,18 @@ from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
+from .. import console_ns
+from ..wraps import (
+ account_initialization_required,
+ edit_permission_required,
+ is_admin_or_owner_required,
+ setup_required,
+)
+
logger = logging.getLogger(__name__)
+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@@ -38,6 +45,7 @@ class TriggerProviderIconApi(Resource):
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
+@console_ns.route("/workspaces/current/triggers")
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@@ -50,6 +58,7 @@ class TriggerProviderListApi(Resource):
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@@ -64,10 +73,11 @@ class TriggerProviderInfoApi(Resource):
)
+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
- @is_admin_or_owner_required
+ @edit_permission_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
@@ -87,19 +97,25 @@ class TriggerSubscriptionListApi(Resource):
raise
+parser = reqparse.RequestParser().add_argument(
+ "credential_type", type=str, required=False, nullable=True, location="json"
+)
+
+
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
+)
class TriggerSubscriptionBuilderCreateApi(Resource):
+ @console_ns.expect(parser)
@setup_required
@login_required
- @is_admin_or_owner_required
+ @edit_permission_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
- parser = reqparse.RequestParser().add_argument(
- "credential_type", type=str, required=False, nullable=True, location="json"
- )
args = parser.parse_args()
try:
@@ -116,9 +132,13 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
raise
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<string:subscription_builder_id>",
+)
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
+ @edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
@@ -127,21 +147,28 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
+parser_api = (
+ reqparse.RequestParser()
+ # The credentials of the subscription builder
+ .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+)
+
+
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<string:subscription_builder_id>",
+)
class TriggerSubscriptionBuilderVerifyApi(Resource):
+ @console_ns.expect(parser_api)
@setup_required
@login_required
- @is_admin_or_owner_required
+ @edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
- parser = (
- reqparse.RequestParser()
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
- )
- args = parser.parse_args()
+
+ args = parser_api.parse_args()
try:
# Use atomic update_and_verify to prevent race conditions
@@ -159,9 +186,27 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
raise ValueError(str(e)) from e
+parser_update_api = (
+ reqparse.RequestParser()
+ # The name of the subscription builder
+ .add_argument("name", type=str, required=False, nullable=True, location="json")
+ # The parameters of the subscription builder
+ .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
+ # The properties of the subscription builder
+ .add_argument("properties", type=dict, required=False, nullable=True, location="json")
+ # The credentials of the subscription builder
+ .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+)
+
+
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<string:subscription_builder_id>",
+)
class TriggerSubscriptionBuilderUpdateApi(Resource):
+ @console_ns.expect(parser_update_api)
@setup_required
@login_required
+ @edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
@@ -169,18 +214,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
- parser = (
- reqparse.RequestParser()
- # The name of the subscription builder
- .add_argument("name", type=str, required=False, nullable=True, location="json")
- # The parameters of the subscription builder
- .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
- # The properties of the subscription builder
- .add_argument("properties", type=dict, required=False, nullable=True, location="json")
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
- )
- args = parser.parse_args()
+ args = parser_update_api.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@@ -200,9 +234,13 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
raise
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<string:subscription_builder_id>",
+)
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
+ @edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
@@ -218,27 +256,20 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
raise
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<string:subscription_builder_id>",
+)
class TriggerSubscriptionBuilderBuildApi(Resource):
+ @console_ns.expect(parser_update_api)
@setup_required
@login_required
- @is_admin_or_owner_required
+ @edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
- parser = (
- reqparse.RequestParser()
- # The name of the subscription builder
- .add_argument("name", type=str, required=False, nullable=True, location="json")
- # The parameters of the subscription builder
- .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
- # The properties of the subscription builder
- .add_argument("properties", type=dict, required=False, nullable=True, location="json")
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
- )
- args = parser.parse_args()
+ args = parser_update_api.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@@ -258,6 +289,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e
+@console_ns.route(
+ "/workspaces/current/trigger-provider/<path:provider>/subscriptions/delete",
+)
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@@ -291,6 +325,7 @@ class TriggerSubscriptionDeleteApi(Resource):
raise
+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@@ -374,6 +409,7 @@ class TriggerOAuthAuthorizeApi(Resource):
raise
+@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
@@ -438,6 +474,14 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
+parser_oauth_client = (
+ reqparse.RequestParser()
+ .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
+ .add_argument("enabled", type=bool, required=False, nullable=True, location="json")
+)
+
+
+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@@ -484,6 +528,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
+ @console_ns.expect(parser_oauth_client)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -493,12 +538,7 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- parser = (
- reqparse.RequestParser()
- .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
- .add_argument("enabled", type=bool, required=False, nullable=True, location="json")
- )
- args = parser.parse_args()
+ args = parser_oauth_client.parse_args()
try:
provider_id = TriggerProviderID(provider)
@@ -536,48 +576,3 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
-
-
-# Trigger Subscription
-api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
-api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
-api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
-api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
-api.add_resource(
- TriggerSubscriptionDeleteApi,
- "/workspaces/current/trigger-provider/<path:provider>/subscriptions/delete",
-)
-
-# Trigger Subscription Builder
-api.add_resource(
- TriggerSubscriptionBuilderCreateApi,
- "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
-)
-api.add_resource(
- TriggerSubscriptionBuilderGetApi,
- "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<string:subscription_builder_id>",
-)
-api.add_resource(
- TriggerSubscriptionBuilderUpdateApi,
- "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<string:subscription_builder_id>",
-)
-api.add_resource(
- TriggerSubscriptionBuilderVerifyApi,
- "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<string:subscription_builder_id>",
-)
-api.add_resource(
- TriggerSubscriptionBuilderBuildApi,
- "/workspaces/current/trigger-provider//subscriptions/builder/build/